Lecture 8. Neural Networks

How to train your neurons

Joaquin Vanschoren

# Note: You'll need to install tensorflow-addons. One of these should work
# !pip install tensorflow_addons
# !pip install tfa-nightly

# TODO: Fix issues running Cyclical Learning rate and AdaMax with latest TF
# Auto-setup when running on Google Colab
# Auto-setup when running on Google Colab: clone the course repo once and
# install its requirements, then switch into the notebooks directory.
# NOTE: `get_ipython()` and the `!` / `%` lines are IPython magics — this
# cell only runs inside a Jupyter/Colab kernel, not as plain Python.
import os
if 'google.colab' in str(get_ipython()) and not os.path.exists('/content/master'):
    !git clone -q https://github.com/ML-course/master.git /content/master
    !pip install -rq master/requirements_colab.txt
    %cd master/notebooks

# Global imports and settings
%matplotlib inline
from preamble import *
interactive = False # Set to True for interactive plots 
if interactive:
    # Smaller scale factor for on-screen interactive figures
    fig_scale = 0.7
    plt.rcParams.update(print_config)
else: # For printing
    fig_scale = 0.8
    plt.rcParams.update(print_config)
# NOTE(review): both branches apply `print_config`; presumably an
# interactive-specific rcParams config was intended above — verify.

Overview

  • Neural architectures

  • Training neural nets

    • Forward pass: Tensor operations

    • Backward pass: Backpropagation

  • Neural network design:

    • Activation functions

    • Weight initialization

    • Optimizers

  • Neural networks in practice

  • Model selection

    • Early stopping

    • Memorization capacity and information bottleneck

    • L1/L2 regularization

    • Dropout

    • Batch normalization

def draw_neural_net(ax, layer_sizes, draw_bias=False, labels=False, activation=False, sigmoid=False,
                    weight_count=False, random_weights=False, show_activations=False, figsize=(4, 4)):
    """
    Draws a dense neural net for educational purposes
    Parameters:
        ax: plot axis
        layer_sizes: array with the sizes of every layer
        draw_bias: whether to draw bias nodes
        labels: whether to draw labels for the weights and nodes
        activation: whether to show the activation function inside the nodes
        sigmoid: whether the last activation function is a sigmoid
        weight_count: whether to show the number of weights and biases
        random_weights: whether to show random weights as colored lines
        show_activations: whether to show a variable for the node activations
        figsize: (width, height) used to scale the drawing area and node layout
    """
    # Drawing-area margins; the right edge scales with the figure aspect ratio
    left, right, bottom, top = 0.1, 0.89*figsize[0]/figsize[1], 0.1, 0.89
    n_layers = len(layer_sizes)
    # Vertical distance between nodes / horizontal distance between layers
    v_spacing = (top - bottom)/float(max(layer_sizes))
    h_spacing = (right - left)/float(len(layer_sizes) - 1)
    # Node colors: input (green), hidden (blue), output (red)
    colors = ['greenyellow','cornflowerblue','lightcoral']
    w_count, b_count = 0, 0
    ax.set_xlim(0, figsize[0]/figsize[1])
    ax.axis('off')
    ax.set_aspect('equal')
    # Shared text styling for node labels
    txtargs = {"fontsize":12, "verticalalignment":'center', "horizontalalignment":'center', "zorder":5}
    
    # Draw biases by adding a node to every layer except the last one
    if draw_bias:
        layer_sizes = [x+1 for x in layer_sizes]
        layer_sizes[-1] = layer_sizes[-1] - 1
        
    # Nodes
    for n, layer_size in enumerate(layer_sizes):
        # y-coordinate of the topmost node in this layer (layers are centered)
        layer_top = v_spacing*(layer_size - 1)/2. + (top + bottom)/2. 
        # Bigger circles when the activation function is drawn inside the node
        node_size = v_spacing/len(layer_sizes) if activation and n!=0 else v_spacing/3.
        if n==0:
            color = colors[0]
        elif n==len(layer_sizes)-1:
            color = colors[2]
        else:
            color = colors[1]
        for m in range(layer_size):
            ax.add_artist(plt.Circle((n*h_spacing + left, layer_top - m*v_spacing), radius=node_size,
                                      color=color, ec='k', zorder=4))
            # Counts every drawn node; input (and bias) nodes are subtracted
            # again before the count is displayed at the bottom
            b_count += 1
            nx, ny = n*h_spacing + left, layer_top - m*v_spacing
            # Endpoints of the vertical divider drawn inside activation nodes
            nsx, nsy = [n*h_spacing + left,n*h_spacing + left], [layer_top - m*v_spacing - 0.5*node_size*2,layer_top - m*v_spacing + 0.5*node_size*2]
            if draw_bias and m==0 and n<len(layer_sizes)-1:
                # Topmost node of each non-output layer is the bias node (constant 1)
                ax.text(nx, ny, r'$1$', **txtargs)
            elif labels and n==0:
                ax.text(n*h_spacing + left,layer_top + v_spacing/1.5, 'input', **txtargs)
                ax.text(nx, ny, r'$x_{}$'.format(m), **txtargs)
            elif labels and n==len(layer_sizes)-1:
                if activation:
                    if sigmoid:
                        ax.text(n*h_spacing + left,layer_top - m*v_spacing, r"$z \;\;\; \sigma$", **txtargs)
                    else:
                        ax.text(n*h_spacing + left,layer_top - m*v_spacing, r"$z_{} \;\; g$".format(m), **txtargs)
                    ax.add_artist(plt.Line2D(nsx, nsy, c='k', zorder=6))
                    if show_activations:        
                        # Prediction label to the right of the output node
                        ax.text(n*h_spacing + left + 1.5*node_size,layer_top - m*v_spacing, r"$\hat{y}$", fontsize=12, 
                                verticalalignment='center', horizontalalignment='left', zorder=5, c='r')

                else:
                    ax.text(nx, ny, r'$o_{}$'.format(m), **txtargs)
                ax.text(n*h_spacing + left,layer_top + v_spacing, 'output', **txtargs)
            elif labels:
                if activation:
                    ax.text(n*h_spacing + left,layer_top - m*v_spacing, r"$z_{} \;\; f$".format(m), **txtargs)
                    ax.add_artist(plt.Line2D(nsx, nsy, c='k', zorder=6))
                    if show_activations:        
                        # Activation label to the right of the hidden node
                        ax.text(n*h_spacing + left + node_size,layer_top - m*v_spacing, r"$a_{}$".format(m), fontsize=12, 
                                verticalalignment='center', horizontalalignment='left', zorder=5, c='b')
                else:
                    ax.text(nx, ny, r'$h_{}$'.format(m), **txtargs)
                
            
    # Edges
    for n, (layer_size_a, layer_size_b) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])):
        layer_top_a = v_spacing*(layer_size_a - 1)/2. + (top + bottom)/2.
        layer_top_b = v_spacing*(layer_size_b - 1)/2. + (top + bottom)/2.
        for m in range(layer_size_a):
            for o in range(layer_size_b):
                # Skip edges *into* a bias node (nothing feeds a constant).
                # NOTE(review): `n<layer_size_b-1` compares a layer index with
                # a layer size — looks suspicious; verify the intended condition.
                if not (draw_bias and o==0 and len(layer_sizes)>2 and n<layer_size_b-1):
                    xs = [n*h_spacing + left, (n + 1)*h_spacing + left]
                    ys = [layer_top_a - m*v_spacing, layer_top_b - o*v_spacing]
                    color = 'k' if not random_weights else plt.cm.bwr(np.random.random())
                    ax.add_artist(plt.Line2D(xs, ys, c=color, lw=1, alpha=0.6))
                    # Edges leaving a bias node count as biases, not weights
                    if not (draw_bias and m==0):
                        w_count += 1
                    if labels and not random_weights:
                        # w_{m,o} when the next layer has several nodes, else w_m
                        wl = r'$w_{{{},{}}}$'.format(m,o) if layer_size_b>1 else r'$w_{}$'.format(m)
                        ax.text(xs[0]+np.diff(xs)/2, np.mean(ys)-np.diff(ys)/9, wl, ha='center', va='center', 
                                 fontsize=10)
    # Count
    if weight_count:
        # Input nodes have no bias; when bias nodes are drawn, exclude them too
        b_count = b_count - layer_sizes[0]
        if draw_bias:
            b_count = b_count - (len(layer_sizes) - 2)
        ax.text(right, bottom, "{} weights, {} biases".format(w_count, b_count), ha='center', va='center')

Linear models as a building block

  • Logistic regression, drawn in a different, neuro-inspired, way

    • Linear model: inner product (\(z\)) of input vector \(\mathbf{x}\) and weight vector \(\mathbf{w}\), plus bias \(w_0\)

    • Logistic (or sigmoid) function maps the output to a probability in [0,1]

    • Uses log loss (cross-entropy) and gradient descent to learn the weights

\[\hat{y}(\mathbf{x}) = \text{sigmoid}(z) = \text{sigmoid}(w_0 + \mathbf{w}\mathbf{x}) = \text{sigmoid}(w_0 + w_1 * x_1 + w_2 * x_2 +... + w_p * x_p)\]
# Logistic regression drawn as a minimal neural net: 4 inputs, 1 sigmoid output.
fig, ax = plt.subplots(figsize=(3 * fig_scale, 3 * fig_scale))
draw_neural_net(ax, [4, 1], activation=True, draw_bias=True, labels=True, sigmoid=True)
../_images/08 - Neural Networks_6_0.png

Basic Architecture

  • Add one (or more) hidden layers \(h\) with \(k\) nodes (or units, cells, neurons)

    • Every ‘neuron’ is a tiny function, the network is an arbitrarily complex function

    • Weights \(w_{i,j}\) between node \(i\) and node \(j\) form a weight matrix \(\mathbf{W}^{(l)}\) per layer \(l\)

  • Every neuron weights the inputs \(\mathbf{x}\) and passes it through a non-linear activation function

    • Activation functions (\(f,g\)) can be different per layer, output \(\mathbf{a}\) is called activation $\(\color{blue}{h(\mathbf{x})} = \color{blue}{\mathbf{a}} = f(\mathbf{z}) = f(\mathbf{W}^{(1)} \color{green}{\mathbf{x}}+\mathbf{w}^{(1)}_0) \quad \quad \color{red}{o(\mathbf{x})} = g(\mathbf{W}^{(2)} \color{blue}{\mathbf{a}}+\mathbf{w}^{(2)}_0)\)$

# Side by side: the plain architecture vs. the same net annotated with activations.
fig, (ax_plain, ax_annotated) = plt.subplots(1, 2, figsize=(8, 4))
draw_neural_net(ax_plain, [2, 3, 1], draw_bias=True, labels=True, weight_count=True)
draw_neural_net(ax_annotated, [2, 3, 1], activation=True, show_activations=True, draw_bias=True, labels=True, weight_count=True)
../_images/08 - Neural Networks_8_0.png

More layers

  • Add more layers, and more nodes per layer, to make the model more complex

    • For simplicity, we don’t draw the biases (but remember that they are there)

  • In dense (fully-connected) layers, every previous layer node is connected to all nodes

  • The output layer can also have multiple nodes (e.g. 1 per class in multi-class classification)

import ipywidgets as widgets
from ipywidgets import interact, interact_manual

# Interactive demo: vary the number of hidden layers and nodes per layer to
# see how quickly the weight count of a dense net grows.
@interact
def plot_dense_net(nr_layers=(0,6,1), nr_nodes=(1,12,1)):
    """Draw a dense net with `nr_layers` hidden layers of `nr_nodes` nodes (5 inputs, 5 outputs)."""
    fig = plt.figure(figsize=(6, 4))
    ax = fig.gca()
    ax.axis('off')
    hidden = [nr_nodes]*nr_layers
    draw_neural_net(ax, [5] + hidden + [5], weight_count=True, figsize=(6, 4));
if not interactive:
    # Static fallback for the printed (non-interactive) version
    plot_dense_net(nr_layers=6, nr_nodes=10)
../_images/08 - Neural Networks_11_0.png

Why layers?

  • Each layer acts as a filter and learns a new representation of the data

    • Subsequent layers can learn iterative refinements

    • Easier than learning a complex relationship in one go

  • Example: for image input, each layer yields new (filtered) images

    • Can learn multiple mappings at once: weight tensor \(\mathit{W}\) yields activation tensor \(\mathit{A}\)

    • From low-level patterns (edges, end-points, …) to combinations thereof

    • Each neuron ‘lights up’ if certain patterns occur in the input

ml

Other architectures

  • There exist MANY types of networks for many different tasks

  • Convolutional nets for image data, Recurrent nets for sequential data,…

  • Also used to learn representations (embeddings), generate new images, text,…

ml

Training Neural Nets

  • Design the architecture, choose activation functions (e.g. sigmoids)

  • Choose a way to initialize the weights (e.g. random initialization)

  • Choose a loss function (e.g. log loss) to measure how well the model fits training data

  • Choose an optimizer (typically an SGD variant) to update the weights

ml

Mini-batch Stochastic Gradient Descent (recap)

  1. Draw a batch of batch_size training data \(\mathbf{X}\) and \(\mathbf{y}\)

  2. Forward pass : pass \(\mathbf{X}\) though the network to yield predictions \(\mathbf{\hat{y}}\)

  3. Compute the loss \(\mathcal{L}\) (mismatch between \(\mathbf{\hat{y}}\) and \(\mathbf{y}\))

  4. Backward pass : Compute the gradient of the loss with regard to every weight

    • Backpropagate the gradients through all the layers

  5. Update \(W\): \(W_{(i+1)} = W_{(i)} - \frac{\partial L(x, W_{(i)})}{\partial W} * \eta\)

Repeat until n passes (epochs) are made through the entire training set

# TODO: show the actual weight updates
@interact
def draw_updates(iteration=(1,100,1)):
    """Simulate weight updates: each slider step reseeds the RNG, so the randomly colored edges change."""
    fig, ax = plt.subplots(1, 1, figsize=(6, 4))
    # Seed with the iteration number: different but reproducible 'weights' per step
    np.random.seed(iteration)
    draw_neural_net(ax, [6,5,5,3], labels=True, random_weights=True, show_activations=True, figsize=(6, 4));
if not interactive:
    draw_updates(iteration=1)
../_images/08 - Neural Networks_17_0.png

Forward pass

  • We can naturally represent the data as tensors

    • Numerical n-dimensional array (with n axes)

    • 2D tensor: matrix (samples, features)

    • 3D tensor: time series (samples, timesteps, features)

    • 4D tensor: color images (samples, height, width, channels)

    • 5D tensor: video (samples, frames, height, width, channels)

ml ml

Tensor operations

  • The operations that the network performs on the data can be reduced to a series of tensor operations

    • These are also much easier to run on GPUs

  • A dense layer with sigmoid activation, input tensor \(\mathbf{X}\), weight tensor \(\mathbf{W}\), bias \(\mathbf{b}\):

y = sigmoid(np.dot(X, W) + b)
  • Tensor dot product for 2D inputs (\(a\) samples, \(b\) features, \(c\) hidden nodes)

ml

Element-wise operations

  • Activation functions and addition are element-wise operations:

def sigmoid(x):
    """Element-wise logistic function 1 / (1 + e^(-x)); maps inputs into (0, 1)."""
    return 1.0 / (np.exp(-x) + 1)

def add(x, y):
    """Element-wise sum; numpy broadcasts when the shapes differ."""
    return x + y
  • Note: if y has a lower dimension than x, it will be broadcasted: axes are added to match the dimensionality, and y is repeated along the new axes

>>> np.array([[1,2],[3,4]]) + np.array([10,20])
array([[11, 22],
       [13, 24]])

Backward pass (backpropagation)

  • For last layer, compute gradient of the loss function \(\mathcal{L}\) w.r.t all weights of layer \(l\)

\[\begin{split}\nabla \mathcal{L} = \frac{\partial \mathcal{L}}{\partial W^{(l)}} = \begin{bmatrix} \frac{\partial \mathcal{L}}{\partial w_{0,0}} & \ldots & \frac{\partial \mathcal{L}}{\partial w_{0,l}} \\ \vdots & \ddots & \vdots \\ \frac{\partial \mathcal{L}}{\partial w_{k,0}} & \ldots & \frac{\partial \mathcal{L}}{\partial w_{k,l}} \end{bmatrix} \\[15pt]\end{split}\]
  • Sum up the gradients for all \(\mathbf{x}_j\) in minibatch: \(\sum_{j} \frac{\partial \mathcal{L}(\mathbf{x}_j,y_j)}{\partial W^{(l)}}\)

  • Update all weights in a layer at once (with learning rate \(\eta\)): \(W_{(i+1)}^{(l)} = W_{(i)}^{(l)} - \eta \sum_{j} \frac{\partial \mathcal{L}(\mathbf{x}_j,y_j)}{\partial W_{(i)}^{(l)}}\)

  • Repeat for next layer, iterating backwards (most efficient, avoids redundant calculations)

ml

Backpropagation (example)

  • Imagine feeding a single data point, output is \(\hat{y} = g(z) = g(w_0 + w_1 * a_1 + w_2 * a_2 +... + w_p * a_p)\)

  • Decrease loss by updating weights:

    • Update the weights of last layer to maximize improvement: \(w_{i,(new)} = w_{i} - \frac{\partial \mathcal{L}}{\partial w_i} * \eta\)

    • To compute gradient \(\frac{\partial \mathcal{L}}{\partial w_i}\) we need the chain rule: \(\frac{\mathrm{d}}{\mathrm{d}x} f(g(x)) = f'(g(x)) * g'(x)\) $\(\frac{\partial \mathcal{L}}{\partial w_i} = \color{red}{\frac{\partial \mathcal{L}}{\partial g}} \color{blue}{\frac{\partial \mathcal{g}}{\partial z_0}} \color{green}{\frac{\partial \mathcal{z_0}}{\partial w_i}}\)$

  • E.g., with \(\mathcal{L} = \frac{1}{2}(y-\hat{y})^2\) and sigmoid \(\sigma\): \(\frac{\partial \mathcal{L}}{\partial w_i} = \color{red}{(y - \hat{y})} * \color{blue}{\sigma'(z_0)} * \color{green}{a_i}\)

fig = plt.figure(figsize=(4, 3.5))
ax = fig.gca()
draw_neural_net(ax, [2, 3, 1],  activation=True, draw_bias=True, labels=True, 
                show_activations=True)
../_images/08 - Neural Networks_23_0.png

Backpropagation (2)

  • Another way to decrease the loss \(\mathcal{L}\) is to update the activations \(a_i\)

    • To update \(a_i = f(z_i)\), we need to update the weights of the previous layer

    • We want to nudge \(a_i\) in the right direction by updating \(w_{i,j}\): $\(\frac{\partial \mathcal{L}}{\partial w_{i,j}} = \frac{\partial \mathcal{L}}{\partial a_i} \frac{\partial a_i}{\partial z_i} \frac{\partial \mathcal{z_i}}{\partial w_{i,j}} = \left( \frac{\partial \mathcal{L}}{\partial g} \frac{\partial \mathcal{g}}{\partial z_0} \frac{\partial \mathcal{z_0}}{\partial a_i} \right) \frac{\partial a_i}{\partial z_i} \frac{\partial \mathcal{z_i}}{\partial w_{i,j}}\)$

    • We know \(\frac{\partial \mathcal{L}}{\partial g}\) and \(\frac{\partial \mathcal{g}}{\partial z_0}\) from the previous step, \(\frac{\partial \mathcal{z_0}}{\partial a_i} = w_i\), \(\frac{\partial a_i}{\partial z_i} = f'\) and \(\frac{\partial \mathcal{z_i}}{\partial w_{i,j}} = x_j\)

fig = plt.figure(figsize=(4, 4))
ax = fig.gca()
draw_neural_net(ax, [2, 3, 1],  activation=True, draw_bias=True, labels=True, 
                show_activations=True)
../_images/08 - Neural Networks_25_0.png

Backpropagation (3)

  • With multiple output nodes, \(\mathcal{L}\) is the sum of all per-output (per-class) losses

    • \(\frac{\partial \mathcal{L}}{\partial a_i}\) is sum of the gradients for every output

  • Per layer, sum up gradients for every point \(\mathbf{x}\) in the batch: \(\sum_{j} \frac{\partial \mathcal{L}(\mathbf{x}_j,y_j)}{\partial W}\)

  • Update all weights of every layer \(l\)

    • \(W_{(i+1)}^{(l)} = W_{(i)}^{(l)} - \eta \sum_{j} \frac{\partial \mathcal{L}(\mathbf{x}_j,y_j)}{\partial W_{(i)}^{(l)}}\)

  • Repeat with a new batch of data until loss converges

  • Nice animation of the entire process

fig = plt.figure(figsize=(8, 4))
ax = fig.gca()
draw_neural_net(ax, [2, 3, 3, 2],  activation=True, draw_bias=True, labels=True, 
                random_weights=True, show_activations=True, figsize=(8, 4))
../_images/08 - Neural Networks_27_0.png

Backpropagation (summary)

  • The network output \(a_o\) is defined by the weights \(W^{(o)}\) and biases \(\mathbf{b}^{(o)}\) of the output layer, and

  • The activations of a hidden layer \(h_1\) with activation function \(a_{h_1}\), weights \(W^{(1)}\) and biases \(\mathbf{b^{(1)}}\):

\[\color{red}{a_o(\mathbf{x})} = \color{red}{a_o(\mathbf{z_0})} = \color{red}{a_o(W^{(o)}} \color{blue}{a_{h_1}(z_{h_1})} \color{red}{+ \mathbf{b}^{(o)})} = \color{red}{a_o(W^{(o)}} \color{blue}{a_{h_1}(W^{(1)} \color{green}{\mathbf{x}} + \mathbf{b}^{(1)})} \color{red}{+ \mathbf{b}^{(o)})} \]
  • Minimize the loss by SGD. For layer \(l\), compute \(\frac{\partial \mathcal{L}(a_o(x))}{\partial W_l}\) and \(\frac{\partial \mathcal{L}(a_o(x))}{\partial b_{l,i}}\) using the chain rule

  • Decomposes into gradient of layer above, gradient of activation function, gradient of layer input:

\[\frac{\partial \mathcal{L}(a_o)}{\partial W^{(1)}} = \color{red}{\frac{\partial \mathcal{L}(a_o)}{\partial a_{h_1}}} \color{blue}{\frac{\partial a_{h_1}}{\partial z_{h_1}}} \color{green}{\frac{\partial z_{h_1}}{\partial W^{(1)}}} = \left( \color{red}{\frac{\partial \mathcal{L}(a_o)}{\partial a_o}} \color{blue}{\frac{\partial a_o}{\partial z_o}} \color{green}{\frac{\partial z_o}{\partial a_{h_1}}}\right) \color{blue}{\frac{\partial a_{h_1}}{\partial z_{h_1}}} \color{green}{\frac{\partial z_{h_1}}{\partial W^{(1)}}} \]
ml

Activation functions for hidden layers

  • Sigmoid: \(f(z) = \frac{1}{1+e^{-z}}\)

  • Tanh: \(f(z) = \frac{2}{1+e^{-2z}} - 1\)

    • Activations around 0 are better for gradient descent convergence

  • Rectified Linear (ReLU): \(f(z) = max(0,z)\)

    • Less smooth, but much faster (note: not differentiable at 0)

  • Leaky ReLU: \(f(z) = \begin{cases} 0.01z & z<0 \\ z & otherwise \end{cases}\)

def activation(X, function="sigmoid"):
    """Apply the named activation function element-wise to X.

    Parameters:
        X: numpy array (or scalar) of pre-activation values
        function: one of "sigmoid", "softmax", "tanh", "relu", "leaky_relu", "none"
    Returns:
        Array (or scalar) of the same shape with the activation applied.
    Raises:
        ValueError: for an unknown activation name. (The original mixed a
        detached `if` for softmax into the chain and silently returned None
        for unrecognized names, hiding typos.)
    """
    if function == "sigmoid":
        return 1.0/(1.0 + np.exp(-X))
    elif function == "softmax":
        # Normalizes over axis 0; fine for the 1-D inputs used in the plots
        return np.exp(X) / np.sum(np.exp(X), axis=0)
    elif function == "tanh":
        return np.tanh(X)
    elif function == "relu":
        return np.maximum(0, X)
    elif function == "leaky_relu":
        # Slope 0.1 (instead of the usual 0.01) so the kink is visible in plots
        return np.maximum(0.1*X, X)
    elif function == "none":
        return X
    raise ValueError("Unknown activation function: {}".format(function))
    
def activation_derivative(X, function="sigmoid"):
    """Element-wise derivative of the named activation, evaluated at X.

    Bug fix: the "none" case returned X/X, which is NaN at X == 0 and emits a
    divide warning; the derivative of the identity is simply 1 everywhere.
    """
    if function == "sigmoid":
        sig = 1.0/(1.0 + np.exp(-X))
        # sigma'(x) = sigma(x) * (1 - sigma(x))
        return sig * (1 - sig)
    elif function == "tanh":
        # tanh'(x) = 1 - tanh(x)^2
        return 1 - np.tanh(X)**2
    elif function == "relu":
        # Subgradient: 0 for x <= 0, 1 for x > 0 (not differentiable at 0)
        return np.where(X > 0, 1, 0)
    elif function == "leaky_relu":
        # Using 0.1 instead of 0.01 to make it visible in the plot
        return np.where(X > 0, 1, 0.1)
    elif function == "none":
        # Identity derivative is 1 (was X/X: NaN at X == 0)
        return np.ones_like(X)
    raise ValueError("Unknown activation function: {}".format(function))
    
def plot_activation(function, ax, derivative=False):
    """Plot activation `function` (blue) on axis `ax`, optionally with its derivative (red).

    Softmax is drawn on a coarse grid with markers, since its value depends
    on the whole input vector rather than a single point.
    """
    if function=="softmax":       
        x = np.linspace(-6,6,9)
        ax.plot(x,activation(x, function),lw=2, c='b', linestyle='-', marker='o')
    else:     
        x = np.linspace(-6,6,101)
        ax.plot(x,activation(x, function),lw=2, c='b', linestyle='-') 
        if derivative:
            # (Leaky) ReLU derivatives are step functions: draw them as steps
            if function == "relu" or function == "leaky_relu":
                ax.step(x,activation_derivative(x, function),lw=2, c='r', linestyle='-')
            else:
                ax.plot(x,activation_derivative(x, function),lw=2, c='r', linestyle='-')
    ax.set_xlabel("input")
    ax.set_ylabel(function)
    ax.grid()
    
# Hidden-layer activation functions shown in the lecture
functions = ["sigmoid","tanh","relu","leaky_relu"]

# Interactive selector for a single activation function
@interact
def plot_activations(function=functions):
    fig, ax = plt.subplots(figsize=(6,2))
    plot_activation(function, ax)
if not interactive:
    # Static fallback: all four activations side by side
    fig, axes = plt.subplots(1,4, figsize=(10,2))
    for function, ax in zip(functions,axes):
        plot_activation(function, ax)
    plt.tight_layout();
../_images/08 - Neural Networks_31_0.png

Effect of activation functions on the gradient

  • During gradient descent, the gradient depends on the activation function \(a_{h}\): \(\frac{\partial \mathcal{L}(a_o)}{\partial W^{(l)}} = \color{red}{\frac{\partial \mathcal{L}(a_o)}{\partial a_{h_l}}} \color{blue}{\frac{\partial a_{h_l}}{\partial z_{h_l}}} \color{green}{\frac{\partial z_{h_l}}{\partial W^{(l)}}}\)

  • If derivative of the activation function \(\color{blue}{\frac{\partial a_{h_l}}{\partial z_{h_l}}}\) is 0, the weights \(w_i\) are not updated

    • Moreover, the gradients of previous layers will be reduced (vanishing gradient)

  • sigmoid, tanh: gradient is very small for large inputs: slow updates

  • With ReLU, \(\color{blue}{\frac{\partial a_{h_l}}{\partial z_{h_l}}} = 1\) if \(z>0\), hence better against vanishing gradients

    • Problem: for very negative inputs, the gradient is 0 and may never recover (dying ReLU)

    • Leaky ReLU has a small (0.01) gradient there to allow recovery

# Same plots as above, but with each activation's derivative overlaid in red
@interact
def plot_activations_derivative(function=functions):
    fig, ax = plt.subplots(figsize=(6,2))
    plot_activation(function, ax, derivative=True)
    plt.legend(['original','derivative'], loc='upper center', 
               bbox_to_anchor=(0.5, 1.25), ncol=2)
if not interactive:
    # Static fallback: all four activations with derivatives
    fig, axes = plt.subplots(1,4, figsize=(10,2))
    for function, ax in zip(functions,axes):
        plot_activation(function, ax, derivative=True)
    fig.legend(['original','derivative'], loc='upper center', 
               bbox_to_anchor=(0.5, 1.25), ncol=2)
    plt.tight_layout();
../_images/08 - Neural Networks_34_0.png

ReLU vs Tanh

  • What is the effect of using non-smooth activation functions?

    • ReLU produces piecewise-linear boundaries, but allows deeper networks

    • Tanh produces smoother decision boundaries, but is slower

from sklearn.neural_network import MLPClassifier
from sklearn.datasets import make_moons
from sklearn.model_selection import train_test_split
from mglearn.plot_2d_separator import plot_2d_classification
import time

# Compare ReLU vs tanh MLPs on two-moons data: decision boundary shape,
# test accuracy, and wall-clock training time for a varying depth.
@interact
def plot_boundary(nr_layers=(1,4,1)):
    """Train two MLPs (ReLU and tanh, `nr_layers` hidden layers of 10 units each) and plot both decision boundaries."""
    X, y = make_moons(n_samples=100, noise=0.25, random_state=3)
    X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y,
                                                        random_state=42)
    
    # Multi-Layer Perceptron with ReLU
    mlp = MLPClassifier(solver='lbfgs', random_state=0,
                        hidden_layer_sizes=[10]*nr_layers)
    start = time.time()
    mlp.fit(X_train, y_train)
    relu_time = time.time() - start
    relu_acc = mlp.score(X_test, y_test)

    # Multi-Layer Perceptron with tanh
    mlp_tanh = MLPClassifier(solver='lbfgs', activation='tanh',
                             random_state=0, hidden_layer_sizes=[10]*nr_layers)
    start = time.time()
    mlp_tanh.fit(X_train, y_train)
    tanh_time = time.time() - start
    tanh_acc = mlp_tanh.score(X_test, y_test)

    # Left: ReLU boundary; right: tanh boundary (training points overlaid)
    fig, axes = plt.subplots(1, 2, figsize=(8, 4))
    axes[0].scatter(X_train[:, 0], X_train[:, 1], c=y_train, cmap='bwr', label="train")
    axes[0].set_title("ReLU, acc: {:.2f}, time: {:.2f} sec".format(relu_acc, relu_time))
    plot_2d_classification(mlp, X_train, fill=True, cm='bwr', alpha=.3, ax=axes[0])
    axes[1].scatter(X_train[:, 0], X_train[:, 1], c=y_train, cmap='bwr', label="train")
    axes[1].set_title("tanh, acc: {:.2f}, time: {:.2f} sec".format(tanh_acc, tanh_time))
    plot_2d_classification(mlp_tanh, X_train, fill=True, cm='bwr', alpha=.3, ax=axes[1])
if not interactive:
    plot_boundary(nr_layers=2)
../_images/08 - Neural Networks_37_0.png

Activation functions for output layer

  • sigmoid converts output to probability in [0,1]

    • For binary classification

  • softmax converts all outputs (aka ‘logits’) to probabilities that sum up to 1

    • For multi-class classification (\(k\) classes)

    • Can cause over-confident models. If so, smooth the labels: \(y_{smooth} = (1-\alpha)y + \frac{\alpha}{k}\) $\(\text{softmax}(\mathbf{x},i) = \frac{e^{x_i}}{\sum_{j=1}^k e^{x_j}}\)$

  • For regression, don’t use any activation function, let the model learn the exact target

# Output-layer activation functions ("none" = identity, for regression)
output_functions = ["sigmoid","softmax","none"]

# Interactive selector for a single output activation
@interact
def plot_output_activation(function=output_functions):
    fig, ax = plt.subplots(figsize=(6,2))
    plot_activation(function, ax)
if not interactive:
    # Static fallback: sigmoid and softmax only (identity is a straight line)
    fig, axes = plt.subplots(1,2, figsize=(8,2))
    for function, ax in zip(output_functions[:2],axes):
        plot_activation(function, ax)
    plt.tight_layout();
../_images/08 - Neural Networks_40_0.png

Weight initialization

  • Initializing weights to 0 is bad: all gradients in layer will be identical (symmetry)

  • Too small random weights shrink activations to 0 along the layers (vanishing gradient)

  • Too large random weights multiply along layers (exploding gradient, zig-zagging)

  • Ideal: small random weights + variance of input and output gradients remains the same

    • Glorot/Xavier initialization (for tanh): randomly sample from \(N(0,\sigma), \sigma = \sqrt{\frac{2}{\text{fan_in + fan_out}}}\)

      • fan_in: number of input units, fan_out: number of output units

    • He initialization (for ReLU): randomly sample from \(N(0,\sigma), \sigma = \sqrt{\frac{2}{\text{fan_in}}}\)

    • Uniform sampling (instead of \(N(0,\sigma)\)) for deeper networks (w.r.t. vanishing gradients)

# A deep net with randomly initialized weights (edge colors = random weight values)
figure, axis = plt.subplots(1, 1, figsize=(6, 3))
draw_neural_net(axis, [3, 5, 5, 5, 5, 5, 3], random_weights=True, figsize=(6, 3))
../_images/08 - Neural Networks_42_0.png

Weight initialization: transfer learning

  • Instead of starting from scratch, start from weights previously learned from similar tasks

    • This is, to a big extent, how humans learn so fast

  • Transfer learning: learn weights on task T, transfer them to new network

    • Weights can be frozen, or finetuned to the new data

  • Only works if the previous task is ‘similar’ enough

    • Meta-learning: learn a good initialization across many related tasks

ml
## Code adapted from Il Gu Yi: https://github.com/ilguyi/optimizers.numpy
from matplotlib.colors import LogNorm
import tensorflow as tf
import tensorflow_addons as tfa

# Toy surface: the Beale function (global minimum at x=3, y=0.5)
def f(x, y):
    return (1.5 - x + x*y)**2 + (2.25 - x + x*y**2)**2 + (2.625 - x + x*y**3)**2

# Tensorflow optimizers to compare (one instance per plotted trajectory)
sgd = tf.optimizers.SGD(0.01)
# Exponentially decaying learning-rate schedule for plain SGD
lr_schedule = tf.optimizers.schedules.ExponentialDecay(0.02,decay_steps=100,decay_rate=0.96)
sgd_decay = tf.optimizers.SGD(learning_rate=lr_schedule)
# Cyclical learning rate and AdaMax are disabled for now; see the TODO at the
# top of the notebook about compatibility with recent TF versions.
#sgd_cyclic = tfa.optimizers.CyclicalLearningRate(initial_learning_rate= 0.1, 
#maximal_learning_rate= 0.5, step_size=0.05)
#clr_schedule = tfa.optimizers.CyclicalLearningRate(initial_learning_rate=1e-4, maximal_learning_rate= 0.1, 
#                                                   step_size=100, scale_fn=lambda x : x)
#sgd_cyclic = tf.optimizers.SGD(learning_rate=clr_schedule)
momentum = tf.optimizers.SGD(0.005, momentum=0.9, nesterov=False)
nesterov = tf.optimizers.SGD(0.005, momentum=0.9, nesterov=True)
adagrad = tf.optimizers.Adagrad(0.4)
#adamax = tf.optimizers.Adamax(learning_rate=0.5, beta_1=0.9, beta_2=0.999)
#adadelta = tf.optimizers.Adadelta(learning_rate=1.0)
rmsprop = tf.optimizers.RMSprop(learning_rate=0.1)
rmsprop_momentum = tf.optimizers.RMSprop(learning_rate=0.1, momentum=0.9)
adam = tf.optimizers.Adam(learning_rate=0.2, beta_1=0.9, beta_2=0.999, epsilon=1e-8)

# Optimizers, display names, and colors are matched up by index
optimizers = [sgd, sgd_decay, momentum, nesterov, adagrad, rmsprop,  rmsprop_momentum, adam]#, sgd_cyclic, adamax]
opt_names = ['sgd', 'sgd_decay', 'momentum', 'nesterov', 'adagrad', 'rmsprop', 'rmsprop_mom', 'adam']#, 'sgd_cyclic','adamax']
cmap = plt.cm.get_cmap('tab10')
colors = [cmap(x/10) for x in range(10)]

# Training: run each optimizer for up to max_steps on the toy surface f(x, y),
# recording the (x, y) path it takes so plot_optimizers can draw it.
all_paths = []
for opt, name in zip(optimizers, opt_names):
    # Every optimizer starts from the same point (0.8, 1.6)
    x_init = 0.8
    x = tf.compat.v1.get_variable('x', dtype=tf.float32, initializer=tf.constant(x_init))
    y_init = 1.6
    y = tf.compat.v1.get_variable('y', dtype=tf.float32, initializer=tf.constant(y_init))

    x_history = []
    y_history = []
    z_prev = 0.0
    max_steps = 100
    for step in range(max_steps):
        # Record the loss computation on the tape; gradients are taken outside
        # the `with` block (recording the gradient computation itself is wasteful)
        with tf.GradientTape() as g:
            z = f(x, y)
        x_history.append(x.numpy())
        y_history.append(y.numpy())
        dz_dx, dz_dy = g.gradient(z, [x, y])
        opt.apply_gradients(zip([dz_dx, dz_dy], [x, y]))
        # Early stopping once the loss plateaus.
        # Bug fix: this check was indented at the outer-loop level, where it
        # aborted the loop over *optimizers* instead of the current run, and
        # z_prev was never updated per step.
        if np.abs(z_prev - z.numpy()) < 1e-6:
            break
        z_prev = z.numpy()

    x_history = np.array(x_history)
    y_history = np.array(y_history)
    # Stack into a 2 x steps array: row 0 = x coordinates, row 1 = y coordinates
    path = np.concatenate((np.expand_dims(x_history, 1), np.expand_dims(y_history, 1)), axis=1).T
    all_paths.append(path)
        
# Plotting: evaluate the toy surface on a mesh for the contour plots below
number_of_points = 50
margin = 4.5
# Global minimum of the Beale surface: (3, 0.5)
minima = np.array([3., .5])
minima_ = minima.reshape(-1, 1)
# Plot window
x_min = 0. - 2
x_max = 0. + 3.5
y_min = 0. - 3.5
y_max = 0. + 2
x_points = np.linspace(x_min, x_max, number_of_points) 
y_points = np.linspace(y_min, y_max, number_of_points)
x_mesh, y_mesh = np.meshgrid(x_points, y_points)
# Surface height at every mesh point (f operates element-wise on each mesh row)
z = np.array([f(xps, yps) for xps, yps in zip(x_mesh, y_mesh)])

def plot_optimizers(ax, iterations, optimizers):
    """Plot the first `iterations` steps of each selected optimizer's path on the loss contours.

    Parameters:
        ax: plot axis
        iterations: number of path steps to show
        optimizers: list of optimizer names; must match entries of `opt_names`
    """
    # Log-spaced contour levels make the narrow valley structure visible
    ax.contour(x_mesh, y_mesh, z, levels=np.logspace(-0.5, 5, 25), norm=LogNorm(), cmap=plt.cm.jet)
    ax.plot(*minima, 'r*', markersize=20)
    for name, path, color in zip(opt_names, all_paths, colors):
        if name in optimizers:
            p = path[:,:iterations]
            # Arrows from each point to the next along the optimization path
            ax.quiver(p[0,:-1], p[1,:-1], p[0,1:]-p[0,:-1], p[1,1:]-p[1,:-1], scale_units='xy', angles='xy', scale=1, color=color, lw=3)
            # Empty line purely to create a legend entry in this color
            ax.plot([], [], color=color, label=name, lw=3, linestyle='-')

    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_xlim((x_min, x_max))
    ax.set_ylim((y_min, y_max))
    ax.legend(loc='lower left', prop={'size': 15})
    plt.tight_layout()
2022-03-16 12:43:05.283936: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:305] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2022-03-16 12:43:05.284367: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:271] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)
Metal device set to: Apple M1 Pro
# Toy plot to illustrate nesterov momentum
# TODO: replace with actual gradient computation?
def plot_nesterov(ax, method="Nesterov momentum"):
    """Hand-drawn illustration of a (Nesterov) momentum update on the loss contour.

    The arrow positions/directions are hard-coded for didactic purposes,
    not computed from actual gradients.
    """
    ax.contour(x_mesh, y_mesh, z, levels=np.logspace(-0.5, 5, 25), norm=LogNorm(), cmap=plt.cm.jet)
    ax.plot(*minima, 'r*', markersize=20)

    # Shared arrow styling for all quiver calls below.
    arrow = dict(scale_units='xy', angles='xy', scale=1, lw=3)
    # Previous update (faded), followed by 0.9 * previous update as the momentum step.
    ax.quiver(-0.8, -1.13, 1, 1.33, color='k', alpha=0.5, label="previous update", **arrow)
    ax.quiver(0.2, 0.2, 0.9, 1.2, color='g', label="momentum step", **arrow)
    if method == "Momentum":
        ax.quiver(0.2, 0.2, 0.5, 0, color='r', label="gradient step", **arrow)
        ax.quiver(0.2, 0.2, 0.9*0.9+0.5, 1.2, color='b', label="actual step", **arrow)
    elif method == "Nesterov momentum":
        ax.quiver(1.1, 1.4, -0.2, -1, color='r', label="'lookahead' gradient step", **arrow)
        ax.quiver(0.2, 0.2, 0.7, 0.2, color='b', label="actual step", **arrow)

    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_title(method)
    ax.set_xlim((x_min, x_max))
    ax.set_ylim((-2.5, y_max))
    ax.legend(loc='lower right', prop={'size': 9})
    plt.tight_layout()

Optimizers

SGD with learning rate schedules

  • Using a constant learning rate \(\eta\) for weight updates \(\mathbf{w}_{(s+1)} = \mathbf{w}_s-\eta\nabla \mathcal{L}(\mathbf{w}_s)\) is not ideal

  • Learning rate decay/annealing with decay rate \(k\)

    • E.g. exponential (\(\eta_{s+1} = \eta_{s} e^{-ks}\)), inverse-time (\(\eta_{s+1} = \frac{\eta_{0}}{1+ks}\)),…

  • Cyclical learning rates

    • Change from small to large: hopefully in ‘good’ region long enough before diverging

    • Warm restarts: aggressive decay + reset to initial learning rate

@interact
def compare_optimizers(iterations=(1,100,1), optimizer1=opt_names, optimizer2=opt_names):
    """Interactively compare the trajectories of two chosen optimizers."""
    figure, axis = plt.subplots(figsize=(6, 4))
    plot_optimizers(axis, iterations, [optimizer1, optimizer2])
# Static (print) version: learning-rate decay vs. cyclical learning rate.
if not interactive:
    fig, axes = plt.subplots(1, 2, figsize=(10, 3))
    # Wrap each name in a list: passing a bare string would make
    # plot_optimizers' membership test a substring match ('sgd' in 'sgd_decay').
    for name, ax in zip(['sgd_decay', 'sgd_cyclic'], axes):
        plot_optimizers(ax, 100, [name])
    plt.tight_layout();
../_images/08 - Neural Networks_48_0.png

Momentum

  • Imagine a ball rolling downhill: accumulates momentum, doesn’t exactly follow steepest descent

    • Reduces oscillation, follows larger (consistent) gradient of the loss surface

  • Adds a velocity vector \(\mathbf{v}\) with momentum \(\gamma\) (e.g. 0.9, or increase from \(\gamma=0.5\) to \(\gamma=0.99\)) $\(\mathbf{w}_{(s+1)} = \mathbf{w}_{(s)} + \mathbf{v}_{(s)} \qquad \text{with} \qquad \color{blue}{\mathbf{v}_{(s)}} = \color{green}{\gamma \mathbf{v}_{(s-1)}} - \color{red}{\eta \nabla \mathcal{L}(\mathbf{w}_{(s)})}\)$

  • Nesterov momentum: Look where momentum step would bring you, compute gradient there

    • Responds faster (and reduces momentum) when the gradient changes $\(\color{blue}{\mathbf{v}_{(s)}} = \color{green}{\gamma \mathbf{v}_{(s-1)}} - \color{red}{\eta \nabla \mathcal{L}(\mathbf{w}_{(s)} + \gamma \mathbf{v}_{(s-1)})}\)$

# Side-by-side illustration: plain momentum vs. Nesterov momentum updates
fig, axes = plt.subplots(1,2, figsize=(10,2.6))
plot_nesterov(axes[0],method="Momentum")
plot_nesterov(axes[1],method="Nesterov momentum")
../_images/08 - Neural Networks_50_0.png

Momentum in practice

@interact
def compare_optimizers(iterations=(1,100,1), optimizer1=opt_names, optimizer2=opt_names):
    """Interactively compare the trajectories of two chosen optimizers."""
    figure, axis = plt.subplots(figsize=(6, 4))
    plot_optimizers(axis, iterations, [optimizer1, optimizer2])
# Static (print) version: SGD vs. momentum, and momentum vs. Nesterov.
if not interactive:
    fig, axes = plt.subplots(1, 2, figsize=(10, 3.5))
    pairs = [['sgd', 'momentum'], ['momentum', 'nesterov']]
    for opts, axis in zip(pairs, axes):
        plot_optimizers(axis, 100, opts)
    plt.tight_layout();
../_images/08 - Neural Networks_53_0.png

Adaptive gradients

  • ‘Correct’ the learning rate for each \(w_i\) based on specific local conditions (layer depth, fan-in,…)

  • Adagrad: scale \(\eta\) according to the sum of squared previous gradients \(G_{i,(s)} = \sum_{t=1}^s \nabla\mathcal{L}(w_{i,(t)})^2\)

    • Update rule for \(w_i\). Usually \(\epsilon=10^{-7}\) (avoids division by 0), \(\eta=0.001\). $\(w_{i,(s+1)} = w_{i,(s)} - \frac{\eta}{\sqrt{G_{i,(s)}+\epsilon}} \nabla \mathcal{L}(w_{i,(s)})\)$

  • RMSProp: use moving average of squared gradients \(m_{i,(s)} = \gamma m_{i,(s-1)} + (1-\gamma) \nabla \mathcal{L}(w_{i,(s)})^2\)

    • Avoids that gradients dwindle to 0 as \(G_{i,(s)}\) grows. Usually \(\gamma=0.9, \eta=0.001\) $\(w_{i,(s+1)} = w_{i,(s)}- \frac{\eta}{\sqrt{m_{i,(s)}+\epsilon}} \nabla \mathcal{L}(w_{i,(s)})\)$

# Static (print) version: adaptive-gradient methods vs. SGD, and RMSProp variants.
if not interactive:
    fig, axes = plt.subplots(1, 2, figsize=(10, 2.6))
    groups = [['sgd', 'adagrad', 'rmsprop'], ['rmsprop', 'rmsprop_mom']]
    for opts, axis in zip(groups, axes):
        plot_optimizers(axis, 100, opts)
    plt.tight_layout();
../_images/08 - Neural Networks_55_0.png
@interact
def compare_optimizers(iterations=(1,100,1), optimizer1=opt_names, optimizer2=opt_names):
    """Interactively compare the trajectories of two chosen optimizers."""
    figure, axis = plt.subplots(figsize=(6, 4))
    plot_optimizers(axis, iterations, [optimizer1, optimizer2])

Adam (Adaptive moment estimation)

  • Adam: RMSProp + momentum. Adds moving average for gradients as well (\(\gamma_2\) = momentum):

    • Adds a bias correction to avoid small initial gradients: \(\hat{m}_{i,(s)} = \frac{m_{i,(s)}}{1-\gamma}\) and \(\hat{g}_{i,(s)} = \frac{g_{i,(s)}}{1-\gamma_2}\) $\(g_{i,(s)} = \gamma_2 g_{i,(s-1)} + (1-\gamma_2) \nabla \mathcal{L}(w_{i,(s)})\)\( \)\(w_{i,(s+1)} = w_{i,(s)}- \frac{\eta}{\sqrt{\hat{m}_{i,(s)}+\epsilon}} \hat{g}_{i,(s)}\)$

  • Adamax: Idem, but use max() instead of moving average: \(u_{i,(s)} = max(\gamma u_{i,(s-1)}, |\nabla\mathcal{L}(w_{i,(s)})|)\) $\(w_{i,(s+1)} = w_{i,(s)}- \frac{\eta}{u_{i,(s)}} \hat{g}_{i,(s)}\)$

# Static (print) version: SGD vs. Adam, and Adam vs. Adamax.
if not interactive:
    fig, axes = plt.subplots(1, 2, figsize=(10, 2.6))
    pairs = [['sgd', 'adam'], ['adam', 'adamax']]
    for opts, axis in zip(pairs, axes):
        plot_optimizers(axis, 100, opts)
    plt.tight_layout();
../_images/08 - Neural Networks_58_0.png
@interact
def compare_optimizers(iterations=(1,100,1), optimizer1=opt_names, optimizer2=opt_names):
    """Interactively compare the trajectories of two chosen optimizers."""
    figure, axis = plt.subplots(figsize=(6, 4))
    plot_optimizers(axis, iterations, [optimizer1, optimizer2])

SGD Optimizer Zoo

  • RMSProp often works well, but do try alternatives. For even more optimizers, see here.

# Static overview: every optimizer's trajectory on a single contour plot.
if not interactive:
    figure, axis = plt.subplots(1, 1, figsize=(10, 5.5))
    plot_optimizers(axis, 100, opt_names)
../_images/08 - Neural Networks_61_0.png
@interact
def compare_optimizers(iterations=(1,100,1)):
    """Interactively show all optimizers' trajectories for a given step budget."""
    figure, axis = plt.subplots(figsize=(10, 6))
    plot_optimizers(axis, iterations, opt_names)
from tensorflow.keras import models
from tensorflow.keras import layers
from numpy.random import seed
from tensorflow.random import set_seed
import random
import os

# Try to make runs reproducible by seeding every RNG in play.
# Original code set the seeds with a literal 0 and only defined seed_value
# afterwards; define it first and use it consistently.
# NOTE: setting PYTHONHASHSEED at runtime does not affect the already-running
# interpreter's hash randomization; it only applies to subprocesses.
seed_value = 0
os.environ['PYTHONHASHSEED'] = str(seed_value)
random.seed(seed_value)   # Python stdlib RNG
seed(seed_value)          # NumPy RNG
set_seed(seed_value)      # TensorFlow RNG

Neural networks in practice

  • There are many practical courses on training neural nets. E.g.:

  • Here, we’ll use Keras, a general API for building neural networks

    • Default API for TensorFlow, also has backends for CNTK, Theano

  • Focus on key design decisions, evaluation, and regularization

  • Running example: Fashion-MNIST

    • 28x28 pixel images of 10 classes of fashion items

# Download FMNIST data. Takes a while the first time.
mnist = oml.datasets.get_dataset(40996)  # OpenML dataset id 40996 = Fashion-MNIST
X, y, _, _ = mnist.get_data(target=mnist.default_target_attribute, dataset_format='array');
X = X.reshape(70000, 28, 28)  # 70000 images of 28x28 pixels
# Human-readable names for the 10 class labels
fmnist_classes = {0:"T-shirt/top", 1: "Trouser", 2: "Pullover", 3: "Dress", 4: "Coat", 5: "Sandal", 
                  6: "Shirt", 7: "Sneaker", 8: "Bag", 9: "Ankle boot"}

# Take some random examples and show them with their class label
from random import randint
fig, axes = plt.subplots(1, 5,  figsize=(10, 5))
for i in range(5):
    # randint is inclusive on both ends: randint(0, 70000) could return 70000,
    # which is out of bounds for a 70000-row array.
    n = randint(0, len(X) - 1)
    axes[i].imshow(X[n], cmap=plt.cm.gray_r)
    axes[i].set_xticks([])
    axes[i].set_yticks([])
    axes[i].set_xlabel("{}".format(fmnist_classes[y[n]]))
plt.show();
../_images/08 - Neural Networks_65_0.png

Building the network

  • We first build a simple sequential model (no branches)

  • Input layer (‘input_shape’): a flat vector of 28*28=784 nodes

    • We’ll see how to properly deal with images later

  • Two dense hidden layers: 512 nodes each, ReLU activation

    • Glorot weight initialization is applied by default

  • Output layer: 10 nodes (for 10 classes) and softmax activation

# Simple feed-forward net: 784 inputs -> two 512-unit ReLU layers -> 10-way softmax.
network = models.Sequential([
    layers.Dense(512, activation='relu', kernel_initializer='he_normal', input_shape=(28 * 28,)),
    layers.Dense(512, activation='relu', kernel_initializer='he_normal'),
    layers.Dense(10, activation='softmax'),
])
from tensorflow.keras import initializers  # NOTE(review): imported but unused in this cell

# Same architecture as above: 784 inputs -> 512 -> 512 -> 10 (softmax output)
network = models.Sequential()
network.add(layers.Dense(512, activation='relu', kernel_initializer='he_normal', input_shape=(28 * 28,)))
network.add(layers.Dense(512, activation='relu', kernel_initializer='he_normal'))
network.add(layers.Dense(10, activation='softmax'))

Model summary

  • Lots of parameters (weights and biases) to learn!

    • hidden layer 1 : (28 * 28 + 1) * 512 = 401920

    • hidden layer 2 : (512 + 1) * 512 = 262656

    • output layer: (512 + 1) * 10 = 5130

# Print layer output shapes and parameter counts.
# (The call was accidentally duplicated in the original; once is enough.)
network.summary()
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 dense (Dense)               (None, 512)               401920    
                                                                 
 dense_1 (Dense)             (None, 512)               262656    
                                                                 
 dense_2 (Dense)             (None, 10)                5130      
                                                                 
=================================================================
Total params: 669,706
Trainable params: 669,706
Non-trainable params: 0
_________________________________________________________________

Choosing loss, optimizer, metrics

  • Loss function

    • Cross-entropy (log loss) for multi-class classification (\(y_{true}\) is one-hot encoded)

    • Use binary crossentropy for binary problems (single output node)

    • Use sparse categorical crossentropy if \(y_{true}\) is label-encoded (1,2,3,…)

  • Optimizer

    • Any of the optimizers we discussed before. RMSprop usually works well.

  • Metrics

    • To monitor performance during training and testing, e.g. accuracy

# Shorthand
network.compile(loss='categorical_crossentropy', optimizer='rmsprop', metrics=['accuracy'])
# Detailed
network.compile(loss=CategoricalCrossentropy(label_smoothing=0.01),
                optimizer=RMSprop(learning_rate=0.001, momentum=0.0),  # fixed: missing comma was a syntax error
                # NOTE(review): metrics.Accuracy() compares raw outputs to one-hot
                # targets elementwise; CategoricalAccuracy() (or the 'accuracy'
                # string, which auto-selects) is likely intended — confirm.
                metrics=[Accuracy()])
from tensorflow.keras.optimizers import RMSprop
from tensorflow.keras.losses import CategoricalCrossentropy
from tensorflow.keras.metrics import Accuracy

# Compile with string shorthands: loss, optimizer and metric are looked up by name.
network.compile(loss='categorical_crossentropy', optimizer='rmsprop', metrics=['accuracy'])

Preprocessing: Normalization, Reshaping, Encoding

  • Always normalize (standardize or min-max) the inputs. Mean should be close to 0.

    • Avoid that some inputs overpower others

    • Speed up convergence

      • Gradients of activation functions \(\frac{\partial a_{h}}{\partial z_{h}}\) are (near) 0 for large inputs

      • If some gradients become much larger than others, SGD will start zig-zagging

  • Reshape the data to fit the shape of the input layer, e.g. (n, 28*28) or (n, 28,28)

    • Tensor with instances in first dimension, rest must match the input layer

  • In multi-class classification, every class is an output node, so one-hot-encode the labels

    • e.g. class ‘4’ becomes [0,0,0,0,1,0,0,0,0,0]

# Normalize pixel values to [0,1], flatten images, one-hot encode the labels.
X = X.astype('float32') / 255
# Fixed: X holds 70000 images (see the download cell), not 60000;
# -1 lets NumPy infer the row count.
X = X.reshape((-1, 28 * 28))
y = to_categorical(y)
from sklearn.model_selection import train_test_split
from tensorflow.keras.utils import to_categorical  # moved up: imports belong at the top of the cell

# Split off a held-out test set (60000 train / 10000 test).
Xf_train, Xf_test, yf_train, yf_test = train_test_split(X, y, train_size=60000, shuffle=True, random_state=0)

# Flatten the 28x28 images into 784-dimensional input vectors.
Xf_train = Xf_train.reshape((60000, 28 * 28))
Xf_test = Xf_test.reshape((10000, 28 * 28))

# Min-max scale pixel intensities to [0,1].
# TODO: check if standardization works better
Xf_train = Xf_train.astype('float32') / 255
Xf_test = Xf_test.astype('float32') / 255

# One-hot encode the class labels (one output node per class).
yf_train = to_categorical(yf_train)
yf_test = to_categorical(yf_test)

Choosing training hyperparameters

  • Number of epochs: enough to allow convergence

    • Too much: model starts overfitting (or just wastes time)

  • Batch size: small batches (e.g. 32, 64,… samples) often preferred

    • ‘Noisy’ training data makes overfitting less likely

      • Larger batches generalize less well (‘generalization gap’)

    • Requires less memory (especially in GPUs)

    • Large batches do speed up training, may converge in fewer epochs

  • Batch size interacts with learning rate

    • Instead of shrinking the learning rate you can increase batch size

# Train for 3 epochs with mini-batches of 32 samples.
# (The original first line referenced undefined X_train/y_train and would
# raise NameError; the preprocessed data is Xf_train/yf_train.)
history = network.fit(Xf_train, yf_train, epochs=3, batch_size=32);
2022-03-16 12:43:34.490590: W tensorflow/core/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz
Epoch 1/3
2022-03-16 12:43:34.772914: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:112] Plugin optimizer for device_type GPU is enabled.
   1/1875 [..............................] - ETA: 15:39 - loss: 2.4362 - accuracy: 0.0312

   9/1875 [..............................] - ETA: 13s - loss: 2.0937 - accuracy: 0.3472  

  17/1875 [..............................] - ETA: 12s - loss: 1.5483 - accuracy: 0.5018

  25/1875 [..............................] - ETA: 12s - loss: 1.3969 - accuracy: 0.5375

  34/1875 [..............................] - ETA: 11s - loss: 1.2232 - accuracy: 0.5855

  42/1875 [..............................] - ETA: 11s - loss: 1.1373 - accuracy: 0.6124

  51/1875 [..............................] - ETA: 11s - loss: 1.0716 - accuracy: 0.6330

  60/1875 [..............................] - ETA: 11s - loss: 1.0320 - accuracy: 0.6422

  68/1875 [>.............................] - ETA: 11s - loss: 0.9942 - accuracy: 0.6558

  77/1875 [>.............................] - ETA: 11s - loss: 0.9666 - accuracy: 0.6664

  85/1875 [>.............................] - ETA: 11s - loss: 0.9478 - accuracy: 0.6724

  94/1875 [>.............................] - ETA: 11s - loss: 0.9263 - accuracy: 0.6785

 102/1875 [>.............................] - ETA: 11s - loss: 0.8979 - accuracy: 0.6857

 111/1875 [>.............................] - ETA: 11s - loss: 0.8843 - accuracy: 0.6914

 119/1875 [>.............................] - ETA: 11s - loss: 0.8619 - accuracy: 0.6993

 127/1875 [=>............................] - ETA: 11s - loss: 0.8515 - accuracy: 0.7030

 135/1875 [=>............................] - ETA: 10s - loss: 0.8332 - accuracy: 0.7088

 143/1875 [=>............................] - ETA: 10s - loss: 0.8263 - accuracy: 0.7120

 151/1875 [=>............................] - ETA: 11s - loss: 0.8147 - accuracy: 0.7146

 159/1875 [=>............................] - ETA: 10s - loss: 0.8036 - accuracy: 0.7188

 167/1875 [=>............................] - ETA: 11s - loss: 0.7937 - accuracy: 0.7214

 175/1875 [=>............................] - ETA: 11s - loss: 0.7858 - accuracy: 0.7236

 183/1875 [=>............................] - ETA: 10s - loss: 0.7803 - accuracy: 0.7249

 191/1875 [==>...........................] - ETA: 10s - loss: 0.7702 - accuracy: 0.7279

 199/1875 [==>...........................] - ETA: 10s - loss: 0.7650 - accuracy: 0.7283

 207/1875 [==>...........................] - ETA: 10s - loss: 0.7608 - accuracy: 0.7296

 215/1875 [==>...........................] - ETA: 10s - loss: 0.7579 - accuracy: 0.7301

 223/1875 [==>...........................] - ETA: 10s - loss: 0.7550 - accuracy: 0.7298

 231/1875 [==>...........................] - ETA: 10s - loss: 0.7467 - accuracy: 0.7321

 240/1875 [==>...........................] - ETA: 10s - loss: 0.7414 - accuracy: 0.7337

 248/1875 [==>...........................] - ETA: 10s - loss: 0.7377 - accuracy: 0.7341

 256/1875 [===>..........................] - ETA: 10s - loss: 0.7344 - accuracy: 0.7355

 264/1875 [===>..........................] - ETA: 10s - loss: 0.7305 - accuracy: 0.7377

 273/1875 [===>..........................] - ETA: 10s - loss: 0.7260 - accuracy: 0.7378

 281/1875 [===>..........................] - ETA: 10s - loss: 0.7195 - accuracy: 0.7409

 289/1875 [===>..........................] - ETA: 10s - loss: 0.7128 - accuracy: 0.7426

 297/1875 [===>..........................] - ETA: 10s - loss: 0.7095 - accuracy: 0.7440

 305/1875 [===>..........................] - ETA: 10s - loss: 0.7024 - accuracy: 0.7468

 314/1875 [====>.........................] - ETA: 10s - loss: 0.6995 - accuracy: 0.7483

 323/1875 [====>.........................] - ETA: 10s - loss: 0.6932 - accuracy: 0.7501

 331/1875 [====>.........................] - ETA: 10s - loss: 0.6903 - accuracy: 0.7517

 339/1875 [====>.........................] - ETA: 9s - loss: 0.6874 - accuracy: 0.7526 

 347/1875 [====>.........................] - ETA: 9s - loss: 0.6812 - accuracy: 0.7551

 356/1875 [====>.........................] - ETA: 9s - loss: 0.6818 - accuracy: 0.7558

 364/1875 [====>.........................] - ETA: 9s - loss: 0.6767 - accuracy: 0.7582

 372/1875 [====>.........................] - ETA: 9s - loss: 0.6722 - accuracy: 0.7594

 380/1875 [=====>........................] - ETA: 9s - loss: 0.6695 - accuracy: 0.7596

 388/1875 [=====>........................] - ETA: 9s - loss: 0.6675 - accuracy: 0.7608

 396/1875 [=====>........................] - ETA: 9s - loss: 0.6649 - accuracy: 0.7614

 404/1875 [=====>........................] - ETA: 9s - loss: 0.6628 - accuracy: 0.7617

 412/1875 [=====>........................] - ETA: 9s - loss: 0.6593 - accuracy: 0.7628

 420/1875 [=====>........................] - ETA: 9s - loss: 0.6584 - accuracy: 0.7632

 428/1875 [=====>........................] - ETA: 9s - loss: 0.6567 - accuracy: 0.7638

 436/1875 [=====>........................] - ETA: 9s - loss: 0.6537 - accuracy: 0.7645

 444/1875 [======>.......................] - ETA: 9s - loss: 0.6517 - accuracy: 0.7653

 452/1875 [======>.......................] - ETA: 9s - loss: 0.6506 - accuracy: 0.7658

 458/1875 [======>.......................] - ETA: 9s - loss: 0.6476 - accuracy: 0.7669

 466/1875 [======>.......................] - ETA: 9s - loss: 0.6442 - accuracy: 0.7679

 474/1875 [======>.......................] - ETA: 9s - loss: 0.6428 - accuracy: 0.7683

 482/1875 [======>.......................] - ETA: 9s - loss: 0.6409 - accuracy: 0.7691

 490/1875 [======>.......................] - ETA: 9s - loss: 0.6368 - accuracy: 0.7706

 498/1875 [======>.......................] - ETA: 8s - loss: 0.6359 - accuracy: 0.7708

 506/1875 [=======>......................] - ETA: 8s - loss: 0.6340 - accuracy: 0.7713

 515/1875 [=======>......................] - ETA: 8s - loss: 0.6351 - accuracy: 0.7706

 524/1875 [=======>......................] - ETA: 8s - loss: 0.6330 - accuracy: 0.7714

 532/1875 [=======>......................] - ETA: 8s - loss: 0.6308 - accuracy: 0.7723

 541/1875 [=======>......................] - ETA: 8s - loss: 0.6291 - accuracy: 0.7729

 550/1875 [=======>......................] - ETA: 8s - loss: 0.6272 - accuracy: 0.7737

 559/1875 [=======>......................] - ETA: 8s - loss: 0.6248 - accuracy: 0.7748

 562/1875 [=======>......................] - ETA: 8s - loss: 0.6250 - accuracy: 0.7748

 570/1875 [========>.....................] - ETA: 8s - loss: 0.6232 - accuracy: 0.7751

 578/1875 [========>.....................] - ETA: 8s - loss: 0.6221 - accuracy: 0.7756

 584/1875 [========>.....................] - ETA: 8s - loss: 0.6207 - accuracy: 0.7762

 592/1875 [========>.....................] - ETA: 8s - loss: 0.6212 - accuracy: 0.7761

 600/1875 [========>.....................] - ETA: 8s - loss: 0.6192 - accuracy: 0.7768

 608/1875 [========>.....................] - ETA: 8s - loss: 0.6177 - accuracy: 0.7774

 616/1875 [========>.....................] - ETA: 8s - loss: 0.6161 - accuracy: 0.7777

 624/1875 [========>.....................] - ETA: 8s - loss: 0.6159 - accuracy: 0.7781

 632/1875 [=========>....................] - ETA: 8s - loss: 0.6132 - accuracy: 0.7791

 640/1875 [=========>....................] - ETA: 8s - loss: 0.6123 - accuracy: 0.7797

 648/1875 [=========>....................] - ETA: 8s - loss: 0.6109 - accuracy: 0.7800

 656/1875 [=========>....................] - ETA: 8s - loss: 0.6085 - accuracy: 0.7808

 664/1875 [=========>....................] - ETA: 7s - loss: 0.6072 - accuracy: 0.7810

 672/1875 [=========>....................] - ETA: 7s - loss: 0.6050 - accuracy: 0.7814

 680/1875 [=========>....................] - ETA: 7s - loss: 0.6043 - accuracy: 0.7816

 688/1875 [==========>...................] - ETA: 7s - loss: 0.6027 - accuracy: 0.7822

 696/1875 [==========>...................] - ETA: 7s - loss: 0.6013 - accuracy: 0.7826

 704/1875 [==========>...................] - ETA: 7s - loss: 0.6003 - accuracy: 0.7828

 712/1875 [==========>...................] - ETA: 7s - loss: 0.5991 - accuracy: 0.7835

 720/1875 [==========>...................] - ETA: 7s - loss: 0.5981 - accuracy: 0.7841

 728/1875 [==========>...................] - ETA: 7s - loss: 0.5966 - accuracy: 0.7845

 736/1875 [==========>...................] - ETA: 7s - loss: 0.5962 - accuracy: 0.7847

 744/1875 [==========>...................] - ETA: 7s - loss: 0.5944 - accuracy: 0.7855

 752/1875 [===========>..................] - ETA: 7s - loss: 0.5929 - accuracy: 0.7860

 760/1875 [===========>..................] - ETA: 7s - loss: 0.5916 - accuracy: 0.7866

 768/1875 [===========>..................] - ETA: 7s - loss: 0.5912 - accuracy: 0.7867

 776/1875 [===========>..................] - ETA: 7s - loss: 0.5907 - accuracy: 0.7868

 784/1875 [===========>..................] - ETA: 7s - loss: 0.5893 - accuracy: 0.7873

 792/1875 [===========>..................] - ETA: 7s - loss: 0.5878 - accuracy: 0.7878

 800/1875 [===========>..................] - ETA: 7s - loss: 0.5867 - accuracy: 0.7886

 808/1875 [===========>..................] - ETA: 7s - loss: 0.5860 - accuracy: 0.7889

 816/1875 [============>.................] - ETA: 6s - loss: 0.5847 - accuracy: 0.7892

 824/1875 [============>.................] - ETA: 6s - loss: 0.5836 - accuracy: 0.7897

 832/1875 [============>.................] - ETA: 6s - loss: 0.5820 - accuracy: 0.7903

 840/1875 [============>.................] - ETA: 6s - loss: 0.5812 - accuracy: 0.7906

 848/1875 [============>.................] - ETA: 6s - loss: 0.5810 - accuracy: 0.7908

 856/1875 [============>.................] - ETA: 6s - loss: 0.5794 - accuracy: 0.7913

 864/1875 [============>.................] - ETA: 6s - loss: 0.5786 - accuracy: 0.7914

 872/1875 [============>.................] - ETA: 6s - loss: 0.5781 - accuracy: 0.7919

 880/1875 [=============>................] - ETA: 6s - loss: 0.5777 - accuracy: 0.7919

 888/1875 [=============>................] - ETA: 6s - loss: 0.5764 - accuracy: 0.7925

 896/1875 [=============>................] - ETA: 6s - loss: 0.5758 - accuracy: 0.7928

 904/1875 [=============>................] - ETA: 6s - loss: 0.5761 - accuracy: 0.7928

 912/1875 [=============>................] - ETA: 6s - loss: 0.5763 - accuracy: 0.7926

 920/1875 [=============>................] - ETA: 6s - loss: 0.5749 - accuracy: 0.7931

 928/1875 [=============>................] - ETA: 6s - loss: 0.5738 - accuracy: 0.7932

 936/1875 [=============>................] - ETA: 6s - loss: 0.5731 - accuracy: 0.7933

 944/1875 [==============>...............] - ETA: 6s - loss: 0.5722 - accuracy: 0.7938

 952/1875 [==============>...............] - ETA: 6s - loss: 0.5714 - accuracy: 0.7941

 960/1875 [==============>...............] - ETA: 6s - loss: 0.5705 - accuracy: 0.7944

 968/1875 [==============>...............] - ETA: 5s - loss: 0.5703 - accuracy: 0.7944

 976/1875 [==============>...............] - ETA: 5s - loss: 0.5691 - accuracy: 0.7949

 984/1875 [==============>...............] - ETA: 5s - loss: 0.5683 - accuracy: 0.7951

 992/1875 [==============>...............] - ETA: 5s - loss: 0.5664 - accuracy: 0.7957

1000/1875 [===============>..............] - ETA: 5s - loss: 0.5655 - accuracy: 0.7960

1008/1875 [===============>..............] - ETA: 5s - loss: 0.5650 - accuracy: 0.7962

1016/1875 [===============>..............] - ETA: 5s - loss: 0.5641 - accuracy: 0.7965

1024/1875 [===============>..............] - ETA: 5s - loss: 0.5625 - accuracy: 0.7971

1032/1875 [===============>..............] - ETA: 5s - loss: 0.5614 - accuracy: 0.7975

1040/1875 [===============>..............] - ETA: 5s - loss: 0.5605 - accuracy: 0.7981

1048/1875 [===============>..............] - ETA: 5s - loss: 0.5593 - accuracy: 0.7985

1057/1875 [===============>..............] - ETA: 5s - loss: 0.5592 - accuracy: 0.7990

1065/1875 [================>.............] - ETA: 5s - loss: 0.5586 - accuracy: 0.7991

1073/1875 [================>.............] - ETA: 5s - loss: 0.5576 - accuracy: 0.7994

1081/1875 [================>.............] - ETA: 5s - loss: 0.5569 - accuracy: 0.7996

1089/1875 [================>.............] - ETA: 5s - loss: 0.5563 - accuracy: 0.7998

1098/1875 [================>.............] - ETA: 5s - loss: 0.5554 - accuracy: 0.7998

1106/1875 [================>.............] - ETA: 5s - loss: 0.5543 - accuracy: 0.8001

1114/1875 [================>.............] - ETA: 5s - loss: 0.5539 - accuracy: 0.8003

1122/1875 [================>.............] - ETA: 4s - loss: 0.5528 - accuracy: 0.8007

1130/1875 [=================>............] - ETA: 4s - loss: 0.5527 - accuracy: 0.8007

1138/1875 [=================>............] - ETA: 4s - loss: 0.5520 - accuracy: 0.8009

1146/1875 [=================>............] - ETA: 4s - loss: 0.5520 - accuracy: 0.8011

1154/1875 [=================>............] - ETA: 4s - loss: 0.5514 - accuracy: 0.8012

1162/1875 [=================>............] - ETA: 4s - loss: 0.5505 - accuracy: 0.8014

1170/1875 [=================>............] - ETA: 4s - loss: 0.5501 - accuracy: 0.8015

1178/1875 [=================>............] - ETA: 4s - loss: 0.5500 - accuracy: 0.8016

1186/1875 [=================>............] - ETA: 4s - loss: 0.5483 - accuracy: 0.8023

1195/1875 [==================>...........] - ETA: 4s - loss: 0.5473 - accuracy: 0.8026

1204/1875 [==================>...........] - ETA: 4s - loss: 0.5468 - accuracy: 0.8028

1213/1875 [==================>...........] - ETA: 4s - loss: 0.5467 - accuracy: 0.8028

1222/1875 [==================>...........] - ETA: 4s - loss: 0.5456 - accuracy: 0.8033

1231/1875 [==================>...........] - ETA: 4s - loss: 0.5450 - accuracy: 0.8034

1240/1875 [==================>...........] - ETA: 4s - loss: 0.5439 - accuracy: 0.8039

1249/1875 [==================>...........] - ETA: 4s - loss: 0.5434 - accuracy: 0.8040

1258/1875 [===================>..........] - ETA: 4s - loss: 0.5429 - accuracy: 0.8042

1267/1875 [===================>..........] - ETA: 3s - loss: 0.5423 - accuracy: 0.8042

1276/1875 [===================>..........] - ETA: 3s - loss: 0.5415 - accuracy: 0.8044

1285/1875 [===================>..........] - ETA: 3s - loss: 0.5409 - accuracy: 0.8046

1293/1875 [===================>..........] - ETA: 3s - loss: 0.5416 - accuracy: 0.8046

1301/1875 [===================>..........] - ETA: 3s - loss: 0.5408 - accuracy: 0.8049

1309/1875 [===================>..........] - ETA: 3s - loss: 0.5394 - accuracy: 0.8053

1318/1875 [====================>.........] - ETA: 3s - loss: 0.5392 - accuracy: 0.8056

1327/1875 [====================>.........] - ETA: 3s - loss: 0.5382 - accuracy: 0.8060

1336/1875 [====================>.........] - ETA: 3s - loss: 0.5371 - accuracy: 0.8063

1345/1875 [====================>.........] - ETA: 3s - loss: 0.5368 - accuracy: 0.8064

1354/1875 [====================>.........] - ETA: 3s - loss: 0.5358 - accuracy: 0.8068

1363/1875 [====================>.........] - ETA: 3s - loss: 0.5358 - accuracy: 0.8070

1371/1875 [====================>.........] - ETA: 3s - loss: 0.5355 - accuracy: 0.8072

1379/1875 [=====================>........] - ETA: 3s - loss: 0.5349 - accuracy: 0.8075

1387/1875 [=====================>........] - ETA: 3s - loss: 0.5344 - accuracy: 0.8078

1395/1875 [=====================>........] - ETA: 3s - loss: 0.5339 - accuracy: 0.8082

1403/1875 [=====================>........] - ETA: 3s - loss: 0.5338 - accuracy: 0.8083

1411/1875 [=====================>........] - ETA: 3s - loss: 0.5342 - accuracy: 0.8080

1419/1875 [=====================>........] - ETA: 2s - loss: 0.5337 - accuracy: 0.8083

1427/1875 [=====================>........] - ETA: 2s - loss: 0.5333 - accuracy: 0.8083

1435/1875 [=====================>........] - ETA: 2s - loss: 0.5333 - accuracy: 0.8084

1443/1875 [======================>.......] - ETA: 2s - loss: 0.5329 - accuracy: 0.8084

1451/1875 [======================>.......] - ETA: 2s - loss: 0.5326 - accuracy: 0.8084

1459/1875 [======================>.......] - ETA: 2s - loss: 0.5323 - accuracy: 0.8085

1467/1875 [======================>.......] - ETA: 2s - loss: 0.5326 - accuracy: 0.8084

1475/1875 [======================>.......] - ETA: 2s - loss: 0.5319 - accuracy: 0.8086

1483/1875 [======================>.......] - ETA: 2s - loss: 0.5312 - accuracy: 0.8089

1491/1875 [======================>.......] - ETA: 2s - loss: 0.5307 - accuracy: 0.8091

1499/1875 [======================>.......] - ETA: 2s - loss: 0.5298 - accuracy: 0.8095

1507/1875 [=======================>......] - ETA: 2s - loss: 0.5293 - accuracy: 0.8097

1514/1875 [=======================>......] - ETA: 2s - loss: 0.5285 - accuracy: 0.8099

1522/1875 [=======================>......] - ETA: 2s - loss: 0.5278 - accuracy: 0.8102

1530/1875 [=======================>......] - ETA: 2s - loss: 0.5278 - accuracy: 0.8103

1538/1875 [=======================>......] - ETA: 2s - loss: 0.5276 - accuracy: 0.8103

1546/1875 [=======================>......] - ETA: 2s - loss: 0.5269 - accuracy: 0.8105

1554/1875 [=======================>......] - ETA: 2s - loss: 0.5262 - accuracy: 0.8108

1562/1875 [=======================>......] - ETA: 2s - loss: 0.5259 - accuracy: 0.8109

1570/1875 [========================>.....] - ETA: 1s - loss: 0.5260 - accuracy: 0.8109

1579/1875 [========================>.....] - ETA: 1s - loss: 0.5250 - accuracy: 0.8112

1588/1875 [========================>.....] - ETA: 1s - loss: 0.5242 - accuracy: 0.8114

1597/1875 [========================>.....] - ETA: 1s - loss: 0.5241 - accuracy: 0.8115

1606/1875 [========================>.....] - ETA: 1s - loss: 0.5231 - accuracy: 0.8118

1615/1875 [========================>.....] - ETA: 1s - loss: 0.5227 - accuracy: 0.8120

1624/1875 [========================>.....] - ETA: 1s - loss: 0.5219 - accuracy: 0.8123

1633/1875 [=========================>....] - ETA: 1s - loss: 0.5216 - accuracy: 0.8126

1642/1875 [=========================>....] - ETA: 1s - loss: 0.5214 - accuracy: 0.8127

1651/1875 [=========================>....] - ETA: 1s - loss: 0.5212 - accuracy: 0.8128

1660/1875 [=========================>....] - ETA: 1s - loss: 0.5204 - accuracy: 0.8130

1669/1875 [=========================>....] - ETA: 1s - loss: 0.5206 - accuracy: 0.8130

1677/1875 [=========================>....] - ETA: 1s - loss: 0.5199 - accuracy: 0.8133

1685/1875 [=========================>....] - ETA: 1s - loss: 0.5190 - accuracy: 0.8136

1694/1875 [==========================>...] - ETA: 1s - loss: 0.5188 - accuracy: 0.8138

1703/1875 [==========================>...] - ETA: 1s - loss: 0.5182 - accuracy: 0.8142

1711/1875 [==========================>...] - ETA: 1s - loss: 0.5174 - accuracy: 0.8144

1720/1875 [==========================>...] - ETA: 1s - loss: 0.5171 - accuracy: 0.8146

1729/1875 [==========================>...] - ETA: 0s - loss: 0.5169 - accuracy: 0.8146

1738/1875 [==========================>...] - ETA: 0s - loss: 0.5165 - accuracy: 0.8146

1746/1875 [==========================>...] - ETA: 0s - loss: 0.5162 - accuracy: 0.8148

1755/1875 [===========================>..] - ETA: 0s - loss: 0.5156 - accuracy: 0.8150

1764/1875 [===========================>..] - ETA: 0s - loss: 0.5149 - accuracy: 0.8152

1773/1875 [===========================>..] - ETA: 0s - loss: 0.5149 - accuracy: 0.8153

1781/1875 [===========================>..] - ETA: 0s - loss: 0.5148 - accuracy: 0.8154

1790/1875 [===========================>..] - ETA: 0s - loss: 0.5146 - accuracy: 0.8155

1799/1875 [===========================>..] - ETA: 0s - loss: 0.5144 - accuracy: 0.8155

1808/1875 [===========================>..] - ETA: 0s - loss: 0.5137 - accuracy: 0.8157

1817/1875 [============================>.] - ETA: 0s - loss: 0.5140 - accuracy: 0.8159

1825/1875 [============================>.] - ETA: 0s - loss: 0.5130 - accuracy: 0.8162

1834/1875 [============================>.] - ETA: 0s - loss: 0.5130 - accuracy: 0.8162

1843/1875 [============================>.] - ETA: 0s - loss: 0.5123 - accuracy: 0.8166

1852/1875 [============================>.] - ETA: 0s - loss: 0.5119 - accuracy: 0.8168

1861/1875 [============================>.] - ETA: 0s - loss: 0.5114 - accuracy: 0.8170

1870/1875 [============================>.] - ETA: 0s - loss: 0.5106 - accuracy: 0.8172

1875/1875 [==============================] - 13s 6ms/step - loss: 0.5103 - accuracy: 0.8174
Epoch 2/3

   1/1875 [..............................] - ETA: 12s - loss: 0.1696 - accuracy: 0.9375

  10/1875 [..............................] - ETA: 11s - loss: 0.4546 - accuracy: 0.8250

  18/1875 [..............................] - ETA: 11s - loss: 0.4349 - accuracy: 0.8351

  27/1875 [..............................] - ETA: 11s - loss: 0.4353 - accuracy: 0.8391

  35/1875 [..............................] - ETA: 11s - loss: 0.4150 - accuracy: 0.8464

  43/1875 [..............................] - ETA: 11s - loss: 0.4317 - accuracy: 0.8452

  51/1875 [..............................] - ETA: 11s - loss: 0.4179 - accuracy: 0.8517

  59/1875 [..............................] - ETA: 11s - loss: 0.4145 - accuracy: 0.8549

  67/1875 [>.............................] - ETA: 11s - loss: 0.4155 - accuracy: 0.8521

  76/1875 [>.............................] - ETA: 11s - loss: 0.4053 - accuracy: 0.8557

  85/1875 [>.............................] - ETA: 11s - loss: 0.4014 - accuracy: 0.8577

  94/1875 [>.............................] - ETA: 11s - loss: 0.3936 - accuracy: 0.8614

 103/1875 [>.............................] - ETA: 11s - loss: 0.3949 - accuracy: 0.8623

 111/1875 [>.............................] - ETA: 11s - loss: 0.4062 - accuracy: 0.8598

 119/1875 [>.............................] - ETA: 11s - loss: 0.4127 - accuracy: 0.8571

 127/1875 [=>............................] - ETA: 11s - loss: 0.4083 - accuracy: 0.8590

 135/1875 [=>............................] - ETA: 10s - loss: 0.4162 - accuracy: 0.8567

 143/1875 [=>............................] - ETA: 10s - loss: 0.4199 - accuracy: 0.8566

 151/1875 [=>............................] - ETA: 10s - loss: 0.4227 - accuracy: 0.8562

 159/1875 [=>............................] - ETA: 10s - loss: 0.4239 - accuracy: 0.8561

 167/1875 [=>............................] - ETA: 10s - loss: 0.4230 - accuracy: 0.8567

 175/1875 [=>............................] - ETA: 10s - loss: 0.4219 - accuracy: 0.8561

 183/1875 [=>............................] - ETA: 10s - loss: 0.4241 - accuracy: 0.8560

 191/1875 [==>...........................] - ETA: 10s - loss: 0.4213 - accuracy: 0.8573

 199/1875 [==>...........................] - ETA: 10s - loss: 0.4253 - accuracy: 0.8568

 207/1875 [==>...........................] - ETA: 10s - loss: 0.4272 - accuracy: 0.8557

 215/1875 [==>...........................] - ETA: 10s - loss: 0.4310 - accuracy: 0.8542

 223/1875 [==>...........................] - ETA: 10s - loss: 0.4307 - accuracy: 0.8538

 231/1875 [==>...........................] - ETA: 10s - loss: 0.4271 - accuracy: 0.8546

 239/1875 [==>...........................] - ETA: 10s - loss: 0.4253 - accuracy: 0.8550

 247/1875 [==>...........................] - ETA: 10s - loss: 0.4240 - accuracy: 0.8551

 255/1875 [===>..........................] - ETA: 10s - loss: 0.4230 - accuracy: 0.8554

 263/1875 [===>..........................] - ETA: 10s - loss: 0.4243 - accuracy: 0.8538

 271/1875 [===>..........................] - ETA: 10s - loss: 0.4256 - accuracy: 0.8533

 279/1875 [===>..........................] - ETA: 10s - loss: 0.4300 - accuracy: 0.8526

 287/1875 [===>..........................] - ETA: 10s - loss: 0.4270 - accuracy: 0.8527

 295/1875 [===>..........................] - ETA: 10s - loss: 0.4260 - accuracy: 0.8529

 303/1875 [===>..........................] - ETA: 10s - loss: 0.4290 - accuracy: 0.8519

 311/1875 [===>..........................] - ETA: 10s - loss: 0.4280 - accuracy: 0.8520

 319/1875 [====>.........................] - ETA: 10s - loss: 0.4293 - accuracy: 0.8523

 327/1875 [====>.........................] - ETA: 10s - loss: 0.4309 - accuracy: 0.8522

 335/1875 [====>.........................] - ETA: 10s - loss: 0.4284 - accuracy: 0.8531

 343/1875 [====>.........................] - ETA: 9s - loss: 0.4282 - accuracy: 0.8530 

 351/1875 [====>.........................] - ETA: 9s - loss: 0.4296 - accuracy: 0.8520

 359/1875 [====>.........................] - ETA: 9s - loss: 0.4284 - accuracy: 0.8525

 367/1875 [====>.........................] - ETA: 9s - loss: 0.4288 - accuracy: 0.8525

 375/1875 [=====>........................] - ETA: 9s - loss: 0.4307 - accuracy: 0.8525

 383/1875 [=====>........................] - ETA: 9s - loss: 0.4295 - accuracy: 0.8532

 391/1875 [=====>........................] - ETA: 9s - loss: 0.4275 - accuracy: 0.8545

 399/1875 [=====>........................] - ETA: 9s - loss: 0.4271 - accuracy: 0.8541

 407/1875 [=====>........................] - ETA: 9s - loss: 0.4277 - accuracy: 0.8537

 415/1875 [=====>........................] - ETA: 9s - loss: 0.4276 - accuracy: 0.8542

 423/1875 [=====>........................] - ETA: 9s - loss: 0.4282 - accuracy: 0.8542

 431/1875 [=====>........................] - ETA: 9s - loss: 0.4287 - accuracy: 0.8538

 439/1875 [======>.......................] - ETA: 9s - loss: 0.4286 - accuracy: 0.8541

 447/1875 [======>.......................] - ETA: 9s - loss: 0.4274 - accuracy: 0.8543

 455/1875 [======>.......................] - ETA: 9s - loss: 0.4265 - accuracy: 0.8546

 463/1875 [======>.......................] - ETA: 9s - loss: 0.4264 - accuracy: 0.8544

 471/1875 [======>.......................] - ETA: 9s - loss: 0.4273 - accuracy: 0.8539

 479/1875 [======>.......................] - ETA: 9s - loss: 0.4261 - accuracy: 0.8538

 487/1875 [======>.......................] - ETA: 9s - loss: 0.4244 - accuracy: 0.8545

 495/1875 [======>.......................] - ETA: 8s - loss: 0.4253 - accuracy: 0.8544

 503/1875 [=======>......................] - ETA: 8s - loss: 0.4239 - accuracy: 0.8548

 511/1875 [=======>......................] - ETA: 8s - loss: 0.4231 - accuracy: 0.8552

 519/1875 [=======>......................] - ETA: 8s - loss: 0.4221 - accuracy: 0.8554

 527/1875 [=======>......................] - ETA: 8s - loss: 0.4226 - accuracy: 0.8556

 535/1875 [=======>......................] - ETA: 8s - loss: 0.4222 - accuracy: 0.8554

 543/1875 [=======>......................] - ETA: 8s - loss: 0.4222 - accuracy: 0.8554

 551/1875 [=======>......................] - ETA: 8s - loss: 0.4232 - accuracy: 0.8549

 559/1875 [=======>......................] - ETA: 8s - loss: 0.4223 - accuracy: 0.8550

 567/1875 [========>.....................] - ETA: 8s - loss: 0.4222 - accuracy: 0.8549

 575/1875 [========>.....................] - ETA: 8s - loss: 0.4227 - accuracy: 0.8548

 583/1875 [========>.....................] - ETA: 8s - loss: 0.4230 - accuracy: 0.8548

 591/1875 [========>.....................] - ETA: 8s - loss: 0.4235 - accuracy: 0.8544

 599/1875 [========>.....................] - ETA: 8s - loss: 0.4226 - accuracy: 0.8548

 607/1875 [========>.....................] - ETA: 8s - loss: 0.4215 - accuracy: 0.8554

 615/1875 [========>.....................] - ETA: 8s - loss: 0.4209 - accuracy: 0.8553

 623/1875 [========>.....................] - ETA: 8s - loss: 0.4222 - accuracy: 0.8549

 631/1875 [=========>....................] - ETA: 8s - loss: 0.4215 - accuracy: 0.8551

 639/1875 [=========>....................] - ETA: 8s - loss: 0.4224 - accuracy: 0.8553

 647/1875 [=========>....................] - ETA: 7s - loss: 0.4225 - accuracy: 0.8554

 655/1875 [=========>....................] - ETA: 7s - loss: 0.4227 - accuracy: 0.8553

 663/1875 [=========>....................] - ETA: 7s - loss: 0.4209 - accuracy: 0.8560

 671/1875 [=========>....................] - ETA: 7s - loss: 0.4210 - accuracy: 0.8556

 679/1875 [=========>....................] - ETA: 7s - loss: 0.4196 - accuracy: 0.8557

 687/1875 [=========>....................] - ETA: 7s - loss: 0.4194 - accuracy: 0.8556

 695/1875 [==========>...................] - ETA: 7s - loss: 0.4211 - accuracy: 0.8555

 703/1875 [==========>...................] - ETA: 7s - loss: 0.4206 - accuracy: 0.8554

 711/1875 [==========>...................] - ETA: 7s - loss: 0.4209 - accuracy: 0.8551

 719/1875 [==========>...................] - ETA: 7s - loss: 0.4201 - accuracy: 0.8554

 727/1875 [==========>...................] - ETA: 7s - loss: 0.4186 - accuracy: 0.8558

 735/1875 [==========>...................] - ETA: 7s - loss: 0.4186 - accuracy: 0.8558

 743/1875 [==========>...................] - ETA: 7s - loss: 0.4181 - accuracy: 0.8560

 751/1875 [===========>..................] - ETA: 7s - loss: 0.4185 - accuracy: 0.8558

 759/1875 [===========>..................] - ETA: 7s - loss: 0.4193 - accuracy: 0.8556

 767/1875 [===========>..................] - ETA: 7s - loss: 0.4181 - accuracy: 0.8561

 775/1875 [===========>..................] - ETA: 7s - loss: 0.4176 - accuracy: 0.8559

 783/1875 [===========>..................] - ETA: 7s - loss: 0.4179 - accuracy: 0.8560

 790/1875 [===========>..................] - ETA: 7s - loss: 0.4171 - accuracy: 0.8562

 798/1875 [===========>..................] - ETA: 7s - loss: 0.4163 - accuracy: 0.8566

 806/1875 [===========>..................] - ETA: 6s - loss: 0.4161 - accuracy: 0.8567

 814/1875 [============>.................] - ETA: 6s - loss: 0.4163 - accuracy: 0.8566

 822/1875 [============>.................] - ETA: 6s - loss: 0.4163 - accuracy: 0.8566

 830/1875 [============>.................] - ETA: 6s - loss: 0.4160 - accuracy: 0.8566

 838/1875 [============>.................] - ETA: 6s - loss: 0.4167 - accuracy: 0.8565

 846/1875 [============>.................] - ETA: 6s - loss: 0.4165 - accuracy: 0.8566

 854/1875 [============>.................] - ETA: 6s - loss: 0.4159 - accuracy: 0.8568

 862/1875 [============>.................] - ETA: 6s - loss: 0.4168 - accuracy: 0.8565

 870/1875 [============>.................] - ETA: 6s - loss: 0.4173 - accuracy: 0.8563

 878/1875 [=============>................] - ETA: 6s - loss: 0.4176 - accuracy: 0.8562

 886/1875 [=============>................] - ETA: 6s - loss: 0.4172 - accuracy: 0.8565

 894/1875 [=============>................] - ETA: 6s - loss: 0.4173 - accuracy: 0.8564

 902/1875 [=============>................] - ETA: 6s - loss: 0.4169 - accuracy: 0.8563

 910/1875 [=============>................] - ETA: 6s - loss: 0.4176 - accuracy: 0.8561

 918/1875 [=============>................] - ETA: 6s - loss: 0.4183 - accuracy: 0.8562

 926/1875 [=============>................] - ETA: 6s - loss: 0.4179 - accuracy: 0.8565

 933/1875 [=============>................] - ETA: 6s - loss: 0.4176 - accuracy: 0.8564

 941/1875 [==============>...............] - ETA: 6s - loss: 0.4178 - accuracy: 0.8563

 949/1875 [==============>...............] - ETA: 6s - loss: 0.4180 - accuracy: 0.8560

 957/1875 [==============>...............] - ETA: 5s - loss: 0.4182 - accuracy: 0.8556

 965/1875 [==============>...............] - ETA: 5s - loss: 0.4182 - accuracy: 0.8559

 973/1875 [==============>...............] - ETA: 5s - loss: 0.4179 - accuracy: 0.8559

 981/1875 [==============>...............] - ETA: 5s - loss: 0.4173 - accuracy: 0.8560

 989/1875 [==============>...............] - ETA: 5s - loss: 0.4186 - accuracy: 0.8559

 997/1875 [==============>...............] - ETA: 5s - loss: 0.4190 - accuracy: 0.8558

1005/1875 [===============>..............] - ETA: 5s - loss: 0.4186 - accuracy: 0.8560

1013/1875 [===============>..............] - ETA: 5s - loss: 0.4182 - accuracy: 0.8563

1021/1875 [===============>..............] - ETA: 5s - loss: 0.4180 - accuracy: 0.8562

1029/1875 [===============>..............] - ETA: 5s - loss: 0.4181 - accuracy: 0.8561

1037/1875 [===============>..............] - ETA: 5s - loss: 0.4174 - accuracy: 0.8564

1045/1875 [===============>..............] - ETA: 5s - loss: 0.4171 - accuracy: 0.8567

1053/1875 [===============>..............] - ETA: 5s - loss: 0.4167 - accuracy: 0.8568

1061/1875 [===============>..............] - ETA: 5s - loss: 0.4175 - accuracy: 0.8566

1069/1875 [================>.............] - ETA: 5s - loss: 0.4176 - accuracy: 0.8566

1077/1875 [================>.............] - ETA: 5s - loss: 0.4181 - accuracy: 0.8563

1085/1875 [================>.............] - ETA: 5s - loss: 0.4182 - accuracy: 0.8562

1093/1875 [================>.............] - ETA: 5s - loss: 0.4184 - accuracy: 0.8560

1101/1875 [================>.............] - ETA: 5s - loss: 0.4184 - accuracy: 0.8560

1109/1875 [================>.............] - ETA: 4s - loss: 0.4179 - accuracy: 0.8563

1117/1875 [================>.............] - ETA: 4s - loss: 0.4180 - accuracy: 0.8561

1125/1875 [=================>............] - ETA: 4s - loss: 0.4183 - accuracy: 0.8561

1133/1875 [=================>............] - ETA: 4s - loss: 0.4181 - accuracy: 0.8562

1141/1875 [=================>............] - ETA: 4s - loss: 0.4177 - accuracy: 0.8563

1149/1875 [=================>............] - ETA: 4s - loss: 0.4173 - accuracy: 0.8563

1157/1875 [=================>............] - ETA: 4s - loss: 0.4168 - accuracy: 0.8566

1160/1875 [=================>............] - ETA: 34s - loss: 0.4168 - accuracy: 0.8565

1166/1875 [=================>............] - ETA: 33s - loss: 0.4170 - accuracy: 0.8566

1173/1875 [=================>............] - ETA: 33s - loss: 0.4167 - accuracy: 0.8568

1181/1875 [=================>............] - ETA: 32s - loss: 0.4164 - accuracy: 0.8569

1189/1875 [==================>...........] - ETA: 31s - loss: 0.4166 - accuracy: 0.8570

1197/1875 [==================>...........] - ETA: 31s - loss: 0.4163 - accuracy: 0.8570

1205/1875 [==================>...........] - ETA: 30s - loss: 0.4163 - accuracy: 0.8570

1212/1875 [==================>...........] - ETA: 30s - loss: 0.4170 - accuracy: 0.8567

1219/1875 [==================>...........] - ETA: 29s - loss: 0.4175 - accuracy: 0.8566

1227/1875 [==================>...........] - ETA: 29s - loss: 0.4179 - accuracy: 0.8566

1235/1875 [==================>...........] - ETA: 28s - loss: 0.4170 - accuracy: 0.8566

1243/1875 [==================>...........] - ETA: 28s - loss: 0.4177 - accuracy: 0.8563

1251/1875 [===================>..........] - ETA: 27s - loss: 0.4182 - accuracy: 0.8563

1259/1875 [===================>..........] - ETA: 27s - loss: 0.4183 - accuracy: 0.8560

1267/1875 [===================>..........] - ETA: 26s - loss: 0.4184 - accuracy: 0.8558

1275/1875 [===================>..........] - ETA: 26s - loss: 0.4188 - accuracy: 0.8558

1283/1875 [===================>..........] - ETA: 25s - loss: 0.4182 - accuracy: 0.8557

1291/1875 [===================>..........] - ETA: 25s - loss: 0.4191 - accuracy: 0.8557

1299/1875 [===================>..........] - ETA: 24s - loss: 0.4186 - accuracy: 0.8556

1307/1875 [===================>..........] - ETA: 24s - loss: 0.4186 - accuracy: 0.8558

1315/1875 [====================>.........] - ETA: 23s - loss: 0.4197 - accuracy: 0.8555

1322/1875 [====================>.........] - ETA: 23s - loss: 0.4197 - accuracy: 0.8555

1330/1875 [====================>.........] - ETA: 23s - loss: 0.4194 - accuracy: 0.8555

1338/1875 [====================>.........] - ETA: 22s - loss: 0.4189 - accuracy: 0.8555

1346/1875 [====================>.........] - ETA: 22s - loss: 0.4197 - accuracy: 0.8551

1353/1875 [====================>.........] - ETA: 21s - loss: 0.4197 - accuracy: 0.8552

1361/1875 [====================>.........] - ETA: 21s - loss: 0.4198 - accuracy: 0.8552

1369/1875 [====================>.........] - ETA: 20s - loss: 0.4191 - accuracy: 0.8554

1377/1875 [=====================>........] - ETA: 20s - loss: 0.4199 - accuracy: 0.8552

1385/1875 [=====================>........] - ETA: 20s - loss: 0.4198 - accuracy: 0.8554

1393/1875 [=====================>........] - ETA: 19s - loss: 0.4195 - accuracy: 0.8555

1400/1875 [=====================>........] - ETA: 19s - loss: 0.4193 - accuracy: 0.8555

1408/1875 [=====================>........] - ETA: 18s - loss: 0.4195 - accuracy: 0.8554

1416/1875 [=====================>........] - ETA: 18s - loss: 0.4192 - accuracy: 0.8556

1424/1875 [=====================>........] - ETA: 18s - loss: 0.4201 - accuracy: 0.8555

1431/1875 [=====================>........] - ETA: 17s - loss: 0.4199 - accuracy: 0.8555

1439/1875 [======================>.......] - ETA: 17s - loss: 0.4197 - accuracy: 0.8555

1447/1875 [======================>.......] - ETA: 16s - loss: 0.4206 - accuracy: 0.8556

1455/1875 [======================>.......] - ETA: 16s - loss: 0.4204 - accuracy: 0.8556

1463/1875 [======================>.......] - ETA: 16s - loss: 0.4200 - accuracy: 0.8557

1471/1875 [======================>.......] - ETA: 15s - loss: 0.4206 - accuracy: 0.8556

1479/1875 [======================>.......] - ETA: 15s - loss: 0.4204 - accuracy: 0.8556

1487/1875 [======================>.......] - ETA: 14s - loss: 0.4204 - accuracy: 0.8556

1495/1875 [======================>.......] - ETA: 14s - loss: 0.4197 - accuracy: 0.8557

1503/1875 [=======================>......] - ETA: 14s - loss: 0.4204 - accuracy: 0.8555

1511/1875 [=======================>......] - ETA: 13s - loss: 0.4209 - accuracy: 0.8553

1519/1875 [=======================>......] - ETA: 13s - loss: 0.4212 - accuracy: 0.8552

1527/1875 [=======================>......] - ETA: 13s - loss: 0.4208 - accuracy: 0.8552

1535/1875 [=======================>......] - ETA: 12s - loss: 0.4205 - accuracy: 0.8553

1543/1875 [=======================>......] - ETA: 12s - loss: 0.4208 - accuracy: 0.8552

1551/1875 [=======================>......] - ETA: 12s - loss: 0.4205 - accuracy: 0.8552

1559/1875 [=======================>......] - ETA: 11s - loss: 0.4209 - accuracy: 0.8552

1567/1875 [========================>.....] - ETA: 11s - loss: 0.4212 - accuracy: 0.8553

1575/1875 [========================>.....] - ETA: 11s - loss: 0.4208 - accuracy: 0.8553

1583/1875 [========================>.....] - ETA: 10s - loss: 0.4202 - accuracy: 0.8556

1591/1875 [========================>.....] - ETA: 10s - loss: 0.4203 - accuracy: 0.8554

1599/1875 [========================>.....] - ETA: 10s - loss: 0.4198 - accuracy: 0.8555

1607/1875 [========================>.....] - ETA: 9s - loss: 0.4204 - accuracy: 0.8555 

1615/1875 [========================>.....] - ETA: 9s - loss: 0.4208 - accuracy: 0.8554

1623/1875 [========================>.....] - ETA: 9s - loss: 0.4207 - accuracy: 0.8555

1631/1875 [=========================>....] - ETA: 8s - loss: 0.4203 - accuracy: 0.8555

1639/1875 [=========================>....] - ETA: 8s - loss: 0.4201 - accuracy: 0.8555

1647/1875 [=========================>....] - ETA: 8s - loss: 0.4204 - accuracy: 0.8554

1654/1875 [=========================>....] - ETA: 7s - loss: 0.4203 - accuracy: 0.8555

1662/1875 [=========================>....] - ETA: 7s - loss: 0.4199 - accuracy: 0.8557

1670/1875 [=========================>....] - ETA: 7s - loss: 0.4197 - accuracy: 0.8557

1678/1875 [=========================>....] - ETA: 6s - loss: 0.4197 - accuracy: 0.8557

1686/1875 [=========================>....] - ETA: 6s - loss: 0.4196 - accuracy: 0.8557

1694/1875 [==========================>...] - ETA: 6s - loss: 0.4196 - accuracy: 0.8556

1702/1875 [==========================>...] - ETA: 5s - loss: 0.4195 - accuracy: 0.8556

1710/1875 [==========================>...] - ETA: 5s - loss: 0.4191 - accuracy: 0.8558

1718/1875 [==========================>...] - ETA: 5s - loss: 0.4190 - accuracy: 0.8558

1726/1875 [==========================>...] - ETA: 5s - loss: 0.4186 - accuracy: 0.8559

1734/1875 [==========================>...] - ETA: 4s - loss: 0.4186 - accuracy: 0.8558

1743/1875 [==========================>...] - ETA: 4s - loss: 0.4185 - accuracy: 0.8557

1751/1875 [===========================>..] - ETA: 4s - loss: 0.4184 - accuracy: 0.8557

1759/1875 [===========================>..] - ETA: 3s - loss: 0.4180 - accuracy: 0.8557

1767/1875 [===========================>..] - ETA: 3s - loss: 0.4180 - accuracy: 0.8556

1775/1875 [===========================>..] - ETA: 3s - loss: 0.4183 - accuracy: 0.8557

1783/1875 [===========================>..] - ETA: 3s - loss: 0.4183 - accuracy: 0.8557

1791/1875 [===========================>..] - ETA: 2s - loss: 0.4184 - accuracy: 0.8557

1799/1875 [===========================>..] - ETA: 2s - loss: 0.4185 - accuracy: 0.8556

1807/1875 [===========================>..] - ETA: 2s - loss: 0.4188 - accuracy: 0.8556

1815/1875 [============================>.] - ETA: 1s - loss: 0.4191 - accuracy: 0.8555

1823/1875 [============================>.] - ETA: 1s - loss: 0.4192 - accuracy: 0.8555

1831/1875 [============================>.] - ETA: 1s - loss: 0.4192 - accuracy: 0.8554

1839/1875 [============================>.] - ETA: 1s - loss: 0.4185 - accuracy: 0.8556

1847/1875 [============================>.] - ETA: 0s - loss: 0.4177 - accuracy: 0.8558

1855/1875 [============================>.] - ETA: 0s - loss: 0.4179 - accuracy: 0.8557

1863/1875 [============================>.] - ETA: 0s - loss: 0.4184 - accuracy: 0.8556

1871/1875 [============================>.] - ETA: 0s - loss: 0.4184 - accuracy: 0.8556

1875/1875 [==============================] - 60s 32ms/step - loss: 0.4183 - accuracy: 0.8556
Epoch 3/3
   1/1875 [..............................] - ETA: 13s - loss: 0.4762 - accuracy: 0.8438

   9/1875 [..............................] - ETA: 12s - loss: 0.3872 - accuracy: 0.8819

  17/1875 [..............................] - ETA: 12s - loss: 0.4069 - accuracy: 0.8695

  25/1875 [..............................] - ETA: 12s - loss: 0.3952 - accuracy: 0.8625

  33/1875 [..............................] - ETA: 12s - loss: 0.3946 - accuracy: 0.8627

  41/1875 [..............................] - ETA: 12s - loss: 0.4041 - accuracy: 0.8575

  49/1875 [..............................] - ETA: 12s - loss: 0.3877 - accuracy: 0.8648

  57/1875 [..............................] - ETA: 12s - loss: 0.3806 - accuracy: 0.8668

  65/1875 [>.............................] - ETA: 12s - loss: 0.3942 - accuracy: 0.8654

  73/1875 [>.............................] - ETA: 11s - loss: 0.3962 - accuracy: 0.8660

  81/1875 [>.............................] - ETA: 11s - loss: 0.3860 - accuracy: 0.8692

  89/1875 [>.............................] - ETA: 11s - loss: 0.3862 - accuracy: 0.8697

  97/1875 [>.............................] - ETA: 11s - loss: 0.3872 - accuracy: 0.8679

 105/1875 [>.............................] - ETA: 11s - loss: 0.3864 - accuracy: 0.8682

 113/1875 [>.............................] - ETA: 11s - loss: 0.3892 - accuracy: 0.8673

 121/1875 [>.............................] - ETA: 11s - loss: 0.3910 - accuracy: 0.8665

 129/1875 [=>............................] - ETA: 11s - loss: 0.3909 - accuracy: 0.8668

 137/1875 [=>............................] - ETA: 11s - loss: 0.3953 - accuracy: 0.8638

 145/1875 [=>............................] - ETA: 11s - loss: 0.3908 - accuracy: 0.8653

 153/1875 [=>............................] - ETA: 11s - loss: 0.3887 - accuracy: 0.8664

 161/1875 [=>............................] - ETA: 11s - loss: 0.3859 - accuracy: 0.8667

 169/1875 [=>............................] - ETA: 11s - loss: 0.3817 - accuracy: 0.8674

 177/1875 [=>............................] - ETA: 11s - loss: 0.3833 - accuracy: 0.8665

 185/1875 [=>............................] - ETA: 11s - loss: 0.3811 - accuracy: 0.8671

 193/1875 [==>...........................] - ETA: 11s - loss: 0.3773 - accuracy: 0.8687

 201/1875 [==>...........................] - ETA: 10s - loss: 0.3760 - accuracy: 0.8696

 209/1875 [==>...........................] - ETA: 10s - loss: 0.3798 - accuracy: 0.8690

 217/1875 [==>...........................] - ETA: 10s - loss: 0.3843 - accuracy: 0.8681

 225/1875 [==>...........................] - ETA: 10s - loss: 0.3807 - accuracy: 0.8694

 233/1875 [==>...........................] - ETA: 10s - loss: 0.3836 - accuracy: 0.8695

 241/1875 [==>...........................] - ETA: 10s - loss: 0.3863 - accuracy: 0.8692

 249/1875 [==>...........................] - ETA: 10s - loss: 0.3859 - accuracy: 0.8691

 257/1875 [===>..........................] - ETA: 10s - loss: 0.3830 - accuracy: 0.8695

 265/1875 [===>..........................] - ETA: 10s - loss: 0.3873 - accuracy: 0.8704

 273/1875 [===>..........................] - ETA: 10s - loss: 0.3887 - accuracy: 0.8703

 281/1875 [===>..........................] - ETA: 10s - loss: 0.3878 - accuracy: 0.8700

 289/1875 [===>..........................] - ETA: 10s - loss: 0.3866 - accuracy: 0.8700

 297/1875 [===>..........................] - ETA: 10s - loss: 0.3900 - accuracy: 0.8695

 305/1875 [===>..........................] - ETA: 10s - loss: 0.3886 - accuracy: 0.8701

 313/1875 [====>.........................] - ETA: 10s - loss: 0.3881 - accuracy: 0.8699

 321/1875 [====>.........................] - ETA: 10s - loss: 0.3893 - accuracy: 0.8698

 329/1875 [====>.........................] - ETA: 10s - loss: 0.3875 - accuracy: 0.8705

 337/1875 [====>.........................] - ETA: 10s - loss: 0.3867 - accuracy: 0.8705

 345/1875 [====>.........................] - ETA: 9s - loss: 0.3862 - accuracy: 0.8699 

 353/1875 [====>.........................] - ETA: 9s - loss: 0.3871 - accuracy: 0.8692

 362/1875 [====>.........................] - ETA: 9s - loss: 0.3867 - accuracy: 0.8688

 370/1875 [====>.........................] - ETA: 9s - loss: 0.3863 - accuracy: 0.8692

 378/1875 [=====>........................] - ETA: 9s - loss: 0.3881 - accuracy: 0.8690

 386/1875 [=====>........................] - ETA: 9s - loss: 0.3917 - accuracy: 0.8678

 394/1875 [=====>........................] - ETA: 9s - loss: 0.3908 - accuracy: 0.8675

 402/1875 [=====>........................] - ETA: 9s - loss: 0.3938 - accuracy: 0.8673

 410/1875 [=====>........................] - ETA: 9s - loss: 0.3929 - accuracy: 0.8673

 418/1875 [=====>........................] - ETA: 9s - loss: 0.3925 - accuracy: 0.8677

 426/1875 [=====>........................] - ETA: 9s - loss: 0.3928 - accuracy: 0.8681

 434/1875 [=====>........................] - ETA: 9s - loss: 0.3915 - accuracy: 0.8686

 442/1875 [======>.......................] - ETA: 9s - loss: 0.3919 - accuracy: 0.8684

 450/1875 [======>.......................] - ETA: 9s - loss: 0.3931 - accuracy: 0.8681

 458/1875 [======>.......................] - ETA: 9s - loss: 0.3929 - accuracy: 0.8682

 466/1875 [======>.......................] - ETA: 9s - loss: 0.3933 - accuracy: 0.8678

 474/1875 [======>.......................] - ETA: 9s - loss: 0.3925 - accuracy: 0.8681

 482/1875 [======>.......................] - ETA: 9s - loss: 0.3929 - accuracy: 0.8681

 490/1875 [======>.......................] - ETA: 9s - loss: 0.3921 - accuracy: 0.8682

 498/1875 [======>.......................] - ETA: 8s - loss: 0.3915 - accuracy: 0.8685

 506/1875 [=======>......................] - ETA: 8s - loss: 0.3907 - accuracy: 0.8686

 514/1875 [=======>......................] - ETA: 8s - loss: 0.3911 - accuracy: 0.8689

 522/1875 [=======>......................] - ETA: 8s - loss: 0.3922 - accuracy: 0.8681

 530/1875 [=======>......................] - ETA: 8s - loss: 0.3933 - accuracy: 0.8676

 538/1875 [=======>......................] - ETA: 8s - loss: 0.3924 - accuracy: 0.8682

 546/1875 [=======>......................] - ETA: 8s - loss: 0.3926 - accuracy: 0.8682

 554/1875 [=======>......................] - ETA: 8s - loss: 0.3938 - accuracy: 0.8679

 562/1875 [=======>......................] - ETA: 8s - loss: 0.3933 - accuracy: 0.8677

 569/1875 [========>.....................] - ETA: 8s - loss: 0.3951 - accuracy: 0.8676

 577/1875 [========>.....................] - ETA: 8s - loss: 0.3938 - accuracy: 0.8679

 585/1875 [========>.....................] - ETA: 8s - loss: 0.3943 - accuracy: 0.8678

 593/1875 [========>.....................] - ETA: 8s - loss: 0.3943 - accuracy: 0.8679

 601/1875 [========>.....................] - ETA: 8s - loss: 0.3951 - accuracy: 0.8678

 609/1875 [========>.....................] - ETA: 8s - loss: 0.3958 - accuracy: 0.8676

 617/1875 [========>.....................] - ETA: 8s - loss: 0.3946 - accuracy: 0.8680

 625/1875 [=========>....................] - ETA: 8s - loss: 0.3947 - accuracy: 0.8679

 633/1875 [=========>....................] - ETA: 8s - loss: 0.3933 - accuracy: 0.8681

 641/1875 [=========>....................] - ETA: 8s - loss: 0.3931 - accuracy: 0.8678

 649/1875 [=========>....................] - ETA: 8s - loss: 0.3937 - accuracy: 0.8672

 657/1875 [=========>....................] - ETA: 7s - loss: 0.3933 - accuracy: 0.8669

 665/1875 [=========>....................] - ETA: 7s - loss: 0.3927 - accuracy: 0.8669

 673/1875 [=========>....................] - ETA: 7s - loss: 0.3933 - accuracy: 0.8668

 681/1875 [=========>....................] - ETA: 7s - loss: 0.3941 - accuracy: 0.8666

 689/1875 [==========>...................] - ETA: 7s - loss: 0.3951 - accuracy: 0.8665

 697/1875 [==========>...................] - ETA: 7s - loss: 0.3950 - accuracy: 0.8662

 705/1875 [==========>...................] - ETA: 7s - loss: 0.3949 - accuracy: 0.8664

 713/1875 [==========>...................] - ETA: 7s - loss: 0.3947 - accuracy: 0.8661

 721/1875 [==========>...................] - ETA: 7s - loss: 0.3953 - accuracy: 0.8660

 729/1875 [==========>...................] - ETA: 7s - loss: 0.3945 - accuracy: 0.8664

 737/1875 [==========>...................] - ETA: 7s - loss: 0.3944 - accuracy: 0.8665

 745/1875 [==========>...................] - ETA: 7s - loss: 0.3948 - accuracy: 0.8662

 753/1875 [===========>..................] - ETA: 7s - loss: 0.3943 - accuracy: 0.8665

 761/1875 [===========>..................] - ETA: 7s - loss: 0.3953 - accuracy: 0.8660

 769/1875 [===========>..................] - ETA: 7s - loss: 0.3969 - accuracy: 0.8656

 777/1875 [===========>..................] - ETA: 7s - loss: 0.3974 - accuracy: 0.8655

 785/1875 [===========>..................] - ETA: 7s - loss: 0.3983 - accuracy: 0.8652

 794/1875 [===========>..................] - ETA: 7s - loss: 0.3986 - accuracy: 0.8652

 802/1875 [===========>..................] - ETA: 6s - loss: 0.3979 - accuracy: 0.8654

 810/1875 [===========>..................] - ETA: 6s - loss: 0.3978 - accuracy: 0.8654

 818/1875 [============>.................] - ETA: 6s - loss: 0.3986 - accuracy: 0.8649

 826/1875 [============>.................] - ETA: 6s - loss: 0.3989 - accuracy: 0.8647

 834/1875 [============>.................] - ETA: 6s - loss: 0.3991 - accuracy: 0.8647

 842/1875 [============>.................] - ETA: 6s - loss: 0.3986 - accuracy: 0.8646

 850/1875 [============>.................] - ETA: 6s - loss: 0.3986 - accuracy: 0.8644

 858/1875 [============>.................] - ETA: 6s - loss: 0.3991 - accuracy: 0.8645

 866/1875 [============>.................] - ETA: 6s - loss: 0.3989 - accuracy: 0.8645

 874/1875 [============>.................] - ETA: 6s - loss: 0.3991 - accuracy: 0.8644

 882/1875 [=============>................] - ETA: 6s - loss: 0.3992 - accuracy: 0.8641

 890/1875 [=============>................] - ETA: 6s - loss: 0.3987 - accuracy: 0.8642

 898/1875 [=============>................] - ETA: 6s - loss: 0.3988 - accuracy: 0.8642

 906/1875 [=============>................] - ETA: 6s - loss: 0.3977 - accuracy: 0.8644

 914/1875 [=============>................] - ETA: 6s - loss: 0.3977 - accuracy: 0.8645

 922/1875 [=============>................] - ETA: 6s - loss: 0.3976 - accuracy: 0.8645

 930/1875 [=============>................] - ETA: 6s - loss: 0.3979 - accuracy: 0.8644

 938/1875 [==============>...............] - ETA: 6s - loss: 0.3976 - accuracy: 0.8643

 946/1875 [==============>...............] - ETA: 6s - loss: 0.3987 - accuracy: 0.8641

 954/1875 [==============>...............] - ETA: 5s - loss: 0.3989 - accuracy: 0.8642

 962/1875 [==============>...............] - ETA: 5s - loss: 0.3992 - accuracy: 0.8642

 970/1875 [==============>...............] - ETA: 5s - loss: 0.3994 - accuracy: 0.8643

 978/1875 [==============>...............] - ETA: 5s - loss: 0.3991 - accuracy: 0.8644

 986/1875 [==============>...............] - ETA: 5s - loss: 0.3988 - accuracy: 0.8644

 989/1875 [==============>...............] - ETA: 53s - loss: 0.3987 - accuracy: 0.8643

 993/1875 [==============>...............] - ETA: 52s - loss: 0.3983 - accuracy: 0.8643

1000/1875 [===============>..............] - ETA: 52s - loss: 0.3984 - accuracy: 0.8643

1008/1875 [===============>..............] - ETA: 51s - loss: 0.3977 - accuracy: 0.8645

1016/1875 [===============>..............] - ETA: 50s - loss: 0.3979 - accuracy: 0.8642

1024/1875 [===============>..............] - ETA: 49s - loss: 0.3979 - accuracy: 0.8642

1032/1875 [===============>..............] - ETA: 48s - loss: 0.3977 - accuracy: 0.8642

1039/1875 [===============>..............] - ETA: 48s - loss: 0.3977 - accuracy: 0.8641

1047/1875 [===============>..............] - ETA: 47s - loss: 0.3974 - accuracy: 0.8642

1055/1875 [===============>..............] - ETA: 46s - loss: 0.3977 - accuracy: 0.8640

1063/1875 [================>.............] - ETA: 45s - loss: 0.3983 - accuracy: 0.8638

1071/1875 [================>.............] - ETA: 45s - loss: 0.3980 - accuracy: 0.8637

1079/1875 [================>.............] - ETA: 44s - loss: 0.3977 - accuracy: 0.8638

1087/1875 [================>.............] - ETA: 43s - loss: 0.3981 - accuracy: 0.8638

1095/1875 [================>.............] - ETA: 42s - loss: 0.3977 - accuracy: 0.8638

1103/1875 [================>.............] - ETA: 42s - loss: 0.3982 - accuracy: 0.8635

1111/1875 [================>.............] - ETA: 41s - loss: 0.3997 - accuracy: 0.8634

1119/1875 [================>.............] - ETA: 40s - loss: 0.3994 - accuracy: 0.8635

1127/1875 [=================>............] - ETA: 40s - loss: 0.3994 - accuracy: 0.8635

1135/1875 [=================>............] - ETA: 39s - loss: 0.3995 - accuracy: 0.8634

1143/1875 [=================>............] - ETA: 38s - loss: 0.3991 - accuracy: 0.8636

1151/1875 [=================>............] - ETA: 38s - loss: 0.3994 - accuracy: 0.8635

1159/1875 [=================>............] - ETA: 37s - loss: 0.3986 - accuracy: 0.8636

1167/1875 [=================>............] - ETA: 36s - loss: 0.3984 - accuracy: 0.8636

1175/1875 [=================>............] - ETA: 36s - loss: 0.3980 - accuracy: 0.8636

1183/1875 [=================>............] - ETA: 35s - loss: 0.3975 - accuracy: 0.8637

1191/1875 [==================>...........] - ETA: 35s - loss: 0.3976 - accuracy: 0.8637

1199/1875 [==================>...........] - ETA: 34s - loss: 0.3980 - accuracy: 0.8637

1207/1875 [==================>...........] - ETA: 33s - loss: 0.3979 - accuracy: 0.8636

1215/1875 [==================>...........] - ETA: 33s - loss: 0.3978 - accuracy: 0.8636

1222/1875 [==================>...........] - ETA: 32s - loss: 0.3978 - accuracy: 0.8637

1230/1875 [==================>...........] - ETA: 32s - loss: 0.3978 - accuracy: 0.8638

1237/1875 [==================>...........] - ETA: 31s - loss: 0.3976 - accuracy: 0.8639

1245/1875 [==================>...........] - ETA: 31s - loss: 0.3971 - accuracy: 0.8640

1253/1875 [===================>..........] - ETA: 30s - loss: 0.3968 - accuracy: 0.8643

1261/1875 [===================>..........] - ETA: 29s - loss: 0.3968 - accuracy: 0.8643

1269/1875 [===================>..........] - ETA: 29s - loss: 0.3970 - accuracy: 0.8645

1276/1875 [===================>..........] - ETA: 28s - loss: 0.3978 - accuracy: 0.8643

1284/1875 [===================>..........] - ETA: 28s - loss: 0.3976 - accuracy: 0.8643

1292/1875 [===================>..........] - ETA: 27s - loss: 0.3975 - accuracy: 0.8643

1300/1875 [===================>..........] - ETA: 27s - loss: 0.3974 - accuracy: 0.8645

1308/1875 [===================>..........] - ETA: 26s - loss: 0.3973 - accuracy: 0.8646

1316/1875 [====================>.........] - ETA: 26s - loss: 0.3971 - accuracy: 0.8647

1324/1875 [====================>.........] - ETA: 25s - loss: 0.3969 - accuracy: 0.8647

1332/1875 [====================>.........] - ETA: 25s - loss: 0.3974 - accuracy: 0.8647

1340/1875 [====================>.........] - ETA: 24s - loss: 0.3976 - accuracy: 0.8646

1348/1875 [====================>.........] - ETA: 24s - loss: 0.3975 - accuracy: 0.8645

1356/1875 [====================>.........] - ETA: 23s - loss: 0.3982 - accuracy: 0.8645

1363/1875 [====================>.........] - ETA: 23s - loss: 0.3989 - accuracy: 0.8642

1371/1875 [====================>.........] - ETA: 22s - loss: 0.3987 - accuracy: 0.8643

1379/1875 [=====================>........] - ETA: 22s - loss: 0.3984 - accuracy: 0.8643

1387/1875 [=====================>........] - ETA: 21s - loss: 0.3982 - accuracy: 0.8644

1395/1875 [=====================>........] - ETA: 21s - loss: 0.3984 - accuracy: 0.8643

1403/1875 [=====================>........] - ETA: 21s - loss: 0.3984 - accuracy: 0.8644

1411/1875 [=====================>........] - ETA: 20s - loss: 0.3989 - accuracy: 0.8641

1419/1875 [=====================>........] - ETA: 20s - loss: 0.3987 - accuracy: 0.8640

1427/1875 [=====================>........] - ETA: 19s - loss: 0.3987 - accuracy: 0.8641

1435/1875 [=====================>........] - ETA: 19s - loss: 0.3994 - accuracy: 0.8639

1443/1875 [======================>.......] - ETA: 18s - loss: 0.3996 - accuracy: 0.8639

1451/1875 [======================>.......] - ETA: 18s - loss: 0.4005 - accuracy: 0.8638

1459/1875 [======================>.......] - ETA: 17s - loss: 0.4006 - accuracy: 0.8638

1467/1875 [======================>.......] - ETA: 17s - loss: 0.4013 - accuracy: 0.8635

1475/1875 [======================>.......] - ETA: 17s - loss: 0.4007 - accuracy: 0.8637

1483/1875 [======================>.......] - ETA: 16s - loss: 0.4005 - accuracy: 0.8638

1491/1875 [======================>.......] - ETA: 16s - loss: 0.4001 - accuracy: 0.8641

1499/1875 [======================>.......] - ETA: 15s - loss: 0.4001 - accuracy: 0.8640

1507/1875 [=======================>......] - ETA: 15s - loss: 0.4003 - accuracy: 0.8641

1515/1875 [=======================>......] - ETA: 15s - loss: 0.4002 - accuracy: 0.8642

1523/1875 [=======================>......] - ETA: 14s - loss: 0.4006 - accuracy: 0.8639

1531/1875 [=======================>......] - ETA: 14s - loss: 0.4010 - accuracy: 0.8638

1539/1875 [=======================>......] - ETA: 13s - loss: 0.4012 - accuracy: 0.8637

1547/1875 [=======================>......] - ETA: 13s - loss: 0.4015 - accuracy: 0.8637

1555/1875 [=======================>......] - ETA: 13s - loss: 0.4019 - accuracy: 0.8635

1563/1875 [========================>.....] - ETA: 12s - loss: 0.4015 - accuracy: 0.8636

1571/1875 [========================>.....] - ETA: 12s - loss: 0.4012 - accuracy: 0.8637

1579/1875 [========================>.....] - ETA: 11s - loss: 0.4007 - accuracy: 0.8639

1587/1875 [========================>.....] - ETA: 11s - loss: 0.4006 - accuracy: 0.8639

1595/1875 [========================>.....] - ETA: 11s - loss: 0.4001 - accuracy: 0.8639

1603/1875 [========================>.....] - ETA: 10s - loss: 0.3993 - accuracy: 0.8642

1612/1875 [========================>.....] - ETA: 10s - loss: 0.3991 - accuracy: 0.8642

1620/1875 [========================>.....] - ETA: 10s - loss: 0.3997 - accuracy: 0.8640

1628/1875 [=========================>....] - ETA: 9s - loss: 0.4001 - accuracy: 0.8641 

1636/1875 [=========================>....] - ETA: 9s - loss: 0.4000 - accuracy: 0.8641

1644/1875 [=========================>....] - ETA: 8s - loss: 0.3998 - accuracy: 0.8640

1652/1875 [=========================>....] - ETA: 8s - loss: 0.3993 - accuracy: 0.8641

1660/1875 [=========================>....] - ETA: 8s - loss: 0.3991 - accuracy: 0.8643

1668/1875 [=========================>....] - ETA: 7s - loss: 0.3988 - accuracy: 0.8642

1676/1875 [=========================>....] - ETA: 7s - loss: 0.3985 - accuracy: 0.8643

1684/1875 [=========================>....] - ETA: 7s - loss: 0.3987 - accuracy: 0.8641

1692/1875 [==========================>...] - ETA: 6s - loss: 0.3988 - accuracy: 0.8642

1701/1875 [==========================>...] - ETA: 6s - loss: 0.3988 - accuracy: 0.8641

1709/1875 [==========================>...] - ETA: 6s - loss: 0.3987 - accuracy: 0.8641

1717/1875 [==========================>...] - ETA: 5s - loss: 0.3986 - accuracy: 0.8642

1725/1875 [==========================>...] - ETA: 5s - loss: 0.3982 - accuracy: 0.8643

1733/1875 [==========================>...] - ETA: 5s - loss: 0.3985 - accuracy: 0.8643

1741/1875 [==========================>...] - ETA: 4s - loss: 0.3987 - accuracy: 0.8642

1749/1875 [==========================>...] - ETA: 4s - loss: 0.3990 - accuracy: 0.8641

1757/1875 [===========================>..] - ETA: 4s - loss: 0.3988 - accuracy: 0.8642

1765/1875 [===========================>..] - ETA: 4s - loss: 0.3985 - accuracy: 0.8644

1773/1875 [===========================>..] - ETA: 3s - loss: 0.3984 - accuracy: 0.8644

1780/1875 [===========================>..] - ETA: 3s - loss: 0.3980 - accuracy: 0.8645

1788/1875 [===========================>..] - ETA: 3s - loss: 0.3979 - accuracy: 0.8646

1796/1875 [===========================>..] - ETA: 2s - loss: 0.3979 - accuracy: 0.8645

1804/1875 [===========================>..] - ETA: 2s - loss: 0.3975 - accuracy: 0.8644

1812/1875 [===========================>..] - ETA: 2s - loss: 0.3977 - accuracy: 0.8644

1820/1875 [============================>.] - ETA: 1s - loss: 0.3977 - accuracy: 0.8645

1828/1875 [============================>.] - ETA: 1s - loss: 0.3975 - accuracy: 0.8646

1836/1875 [============================>.] - ETA: 1s - loss: 0.3978 - accuracy: 0.8645

1844/1875 [============================>.] - ETA: 1s - loss: 0.3980 - accuracy: 0.8644

1850/1875 [============================>.] - ETA: 0s - loss: 0.3978 - accuracy: 0.8644

1858/1875 [============================>.] - ETA: 0s - loss: 0.3976 - accuracy: 0.8645

1865/1875 [============================>.] - ETA: 0s - loss: 0.3975 - accuracy: 0.8645

1873/1875 [============================>.] - ETA: 0s - loss: 0.3980 - accuracy: 0.8644

1875/1875 [==============================] - 66s 35ms/step - loss: 0.3980 - accuracy: 0.8643

Predictions and evaluations

We can now call predict to generate predictions, and evaluate the trained model on the entire test set

# Generate class-probability predictions, then evaluate loss/accuracy on the test set.
# NOTE(review): the first two lines use X_test/y_test while the rest of the cell uses
# Xf_test/yf_test — presumably the same flattened test data under two names; verify upstream.
network.predict(X_test)
test_loss, test_acc = network.evaluate(X_test, y_test)
np.set_printoptions(precision=7)  # show the probability vector with 7 decimals
fig, axes = plt.subplots(1, 1, figsize=(2, 2))
sample_id = 4  # index of the test example to visualize
# assumes each row of Xf_test is a flattened 28x28 grayscale image — TODO confirm
axes.imshow(Xf_test[sample_id].reshape(28, 28), cmap=plt.cm.gray_r)
axes.set_xlabel("True label: {}".format(yf_test[sample_id]))
axes.set_xticks([])
axes.set_yticks([])
print(network.predict(Xf_test)[sample_id])  # per-class predicted probabilities for this sample
2022-03-16 12:45:53.112813: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:112] Plugin optimizer for device_type GPU is enabled.
[0.007654  0.0000028 0.7707164 0.0006719 0.0179828 0.        0.2029032
 0.        0.0000689 0.       ]
../_images/08 - Neural Networks_77_2.png
test_loss, test_acc = network.evaluate(Xf_test, yf_test)
print('Test accuracy:', test_acc)
  1/313 [..............................] - ETA: 40s - loss: 0.2433 - accuracy: 0.9375

  6/313 [..............................] - ETA: 3s - loss: 0.4325 - accuracy: 0.8594 

 14/313 [>.............................] - ETA: 2s - loss: 0.3654 - accuracy: 0.8772
2022-03-16 12:45:53.782057: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:112] Plugin optimizer for device_type GPU is enabled.

 20/313 [>.............................] - ETA: 2s - loss: 0.3526 - accuracy: 0.8766

 28/313 [=>............................] - ETA: 2s - loss: 0.3395 - accuracy: 0.8806

 36/313 [==>...........................] - ETA: 2s - loss: 0.3421 - accuracy: 0.8863

 44/313 [===>..........................] - ETA: 2s - loss: 0.3404 - accuracy: 0.8828

 52/313 [===>..........................] - ETA: 1s - loss: 0.3485 - accuracy: 0.8816

 61/313 [====>.........................] - ETA: 1s - loss: 0.3540 - accuracy: 0.8842

 69/313 [=====>........................] - ETA: 53:53 - loss: 0.3495 - accuracy: 0.8827

 74/313 [======>.......................] - ETA: 49:10 - loss: 0.3427 - accuracy: 0.8856

 81/313 [======>.......................] - ETA: 43:33 - loss: 0.3434 - accuracy: 0.8862

 89/313 [=======>......................] - ETA: 38:14 - loss: 0.3404 - accuracy: 0.8848

 98/313 [========>.....................] - ETA: 33:18 - loss: 0.3447 - accuracy: 0.8839

102/313 [========>.....................] - ETA: 31:23 - loss: 0.3465 - accuracy: 0.8830

110/313 [=========>....................] - ETA: 27:59 - loss: 0.3490 - accuracy: 0.8813

118/313 [==========>...................] - ETA: 25:02 - loss: 0.3547 - accuracy: 0.8827

126/313 [===========>..................] - ETA: 22:28 - loss: 0.3599 - accuracy: 0.8812

134/313 [===========>..................] - ETA: 20:13 - loss: 0.3588 - accuracy: 0.8806

140/313 [============>.................] - ETA: 18:42 - loss: 0.3590 - accuracy: 0.8804

148/313 [=============>................] - ETA: 16:52 - loss: 0.3610 - accuracy: 0.8801

155/313 [=============>................] - ETA: 15:25 - loss: 0.3621 - accuracy: 0.8786

161/313 [==============>...............] - ETA: 14:16 - loss: 0.3613 - accuracy: 0.8781

169/313 [===============>..............] - ETA: 12:53 - loss: 0.3683 - accuracy: 0.8776

178/313 [================>.............] - ETA: 11:28 - loss: 0.3682 - accuracy: 0.8771

187/313 [================>.............] - ETA: 10:11 - loss: 0.3678 - accuracy: 0.8782

196/313 [=================>............] - ETA: 9:01 - loss: 0.3663 - accuracy: 0.8777 

205/313 [==================>...........] - ETA: 7:57 - loss: 0.3627 - accuracy: 0.8784

214/313 [===================>..........] - ETA: 6:59 - loss: 0.3641 - accuracy: 0.8789

223/313 [====================>.........] - ETA: 6:05 - loss: 0.3620 - accuracy: 0.8793

232/313 [=====================>........] - ETA: 5:16 - loss: 0.3608 - accuracy: 0.8792

241/313 [======================>.......] - ETA: 4:30 - loss: 0.3646 - accuracy: 0.8784

251/313 [=======================>......] - ETA: 3:43 - loss: 0.3660 - accuracy: 0.8780

260/313 [=======================>......] - ETA: 3:04 - loss: 0.3645 - accuracy: 0.8784

269/313 [========================>.....] - ETA: 2:28 - loss: 0.3646 - accuracy: 0.8780

278/313 [=========================>....] - ETA: 1:54 - loss: 0.3651 - accuracy: 0.8779

287/313 [==========================>...] - ETA: 1:22 - loss: 0.3695 - accuracy: 0.8772

295/313 [===========================>..] - ETA: 55s - loss: 0.3671 - accuracy: 0.8775 

304/313 [============================>.] - ETA: 26s - loss: 0.3656 - accuracy: 0.8773

313/313 [==============================] - ETA: 0s - loss: 0.3672 - accuracy: 0.8764 

313/313 [==============================] - 903s 3s/step - loss: 0.3672 - accuracy: 0.8764
Test accuracy: 0.8764000535011292

Model selection

  • How many epochs do we need for training?

  • Train the neural net and track the loss after every iteration on a validation set

    • You can add a callback to the fit function to get info on every epoch

  • Best model after a few epochs, then starts overfitting

from tensorflow.keras.callbacks import Callback
from IPython.display import clear_output

# For plotting the learning curve in real time
# For plotting the learning curve in real time
class TrainingPlot(Callback):
    """Keras callback that redraws the train/validation loss and accuracy
    curves in the notebook output after every epoch.

    Tracks per-epoch `loss`, `accuracy`, `val_loss`, `val_accuracy` and the
    best validation accuracy seen so far (`max_acc`).
    """

    # This function is called when the training begins
    def on_train_begin(self, logs=None):
        # Fix: `logs={}` was a mutable default argument — replaced with None.
        # Initialize the lists for holding the logs, losses and accuracies
        self.losses = []
        self.acc = []
        self.val_losses = []
        self.val_acc = []
        self.logs = []
        self.max_acc = 0

    # This function is called at the end of each epoch
    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}  # Keras passes a dict; fall back to {} defensively

        # Append the logs, losses and accuracies to the lists
        self.logs.append(logs)
        self.losses.append(logs.get('loss'))
        self.acc.append(logs.get('accuracy'))
        self.val_losses.append(logs.get('val_loss'))
        val_accuracy = logs.get('val_accuracy')
        self.val_acc.append(val_accuracy)
        # Guard: without validation data 'val_accuracy' is absent and
        # max(0, None) would raise a TypeError
        if val_accuracy is not None:
            self.max_acc = max(self.max_acc, val_accuracy)

        # Before plotting ensure at least 2 epochs have passed
        if len(self.losses) > 1:

            # Clear the previous plot
            clear_output(wait=True)
            N = np.arange(0, len(self.losses))

            # Plot train loss, train acc, val loss and val acc against epochs passed
            plt.figure(figsize=(8,3))
            plt.plot(N, self.losses, lw=2, c="b", linestyle="-", label = "train_loss")
            plt.plot(N, self.acc, lw=2, c="r", linestyle="-", label = "train_acc")
            plt.plot(N, self.val_losses, lw=2, c="b", linestyle=":", label = "val_loss")
            plt.plot(N, self.val_acc, lw=2, c="r", linestyle=":", label = "val_acc")
            plt.title("Training Loss and Accuracy [Epoch {}, Max Acc {:.4f}]".format(epoch, self.max_acc))
            plt.xlabel("Epoch #")
            plt.ylabel("Loss/Accuracy")
            plt.legend()
            plt.show()
from sklearn.model_selection import train_test_split

# Hold out the first 10,000 training examples as a validation set.
# NOTE(review): train_test_split is imported but a plain slice is used instead.
x_val, partial_x_train = Xf_train[:10000], Xf_train[10000:]
y_val, partial_y_train = yf_train[:10000], yf_train[10000:] 
# Two 512-unit ReLU hidden layers (He init) and a 10-class softmax output
network = models.Sequential()
network.add(layers.Dense(512, activation='relu', kernel_initializer='he_normal', input_shape=(28 * 28,)))
network.add(layers.Dense(512, activation='relu', kernel_initializer='he_normal'))
network.add(layers.Dense(10, activation='softmax'))
network.compile(loss='categorical_crossentropy', optimizer='rmsprop', metrics=['accuracy'])
plot_losses = TrainingPlot()  # live learning-curve plot (callback defined above)
history = network.fit(partial_x_train, partial_y_train, epochs=25, batch_size=512, verbose=0,
                      validation_data=(x_val, y_val), callbacks=[plot_losses])
../_images/08 - Neural Networks_81_0.png

Early stopping

  • Stop training when the validation loss (or validation accuracy) no longer improves

  • Loss can be bumpy: use a moving average or wait for \(k\) steps without improvement

# Illustrative slide snippet: stop once val_loss has not improved for 3 consecutive epochs
# (assumes `callbacks`, `model`, `x_train` and `y_train` are defined; runnable version follows)
earlystop = callbacks.EarlyStopping(monitor='val_loss', patience=3)
model.fit(x_train, y_train, epochs=25, batch_size=512, callbacks=[earlystop])
from tensorflow.keras import callbacks

# Stop training when validation loss fails to improve for 3 epochs in a row
earlystop = callbacks.EarlyStopping(monitor='val_loss', patience=3)

# Same 512-512-10 architecture as before, now trained with early stopping
network = models.Sequential()
network.add(layers.Dense(512, activation='relu', kernel_initializer='he_normal', input_shape=(28 * 28,)))
network.add(layers.Dense(512, activation='relu', kernel_initializer='he_normal'))
network.add(layers.Dense(10, activation='softmax'))
network.compile(loss='categorical_crossentropy', optimizer='rmsprop', metrics=['accuracy'])
plot_losses = TrainingPlot()
history = network.fit(partial_x_train, partial_y_train, epochs=25, batch_size=512, verbose=0,
                      validation_data=(x_val, y_val), callbacks=[plot_losses, earlystop])
../_images/08 - Neural Networks_83_0.png

Regularization and memorization capacity

  • The number of learnable parameters is called the model capacity

  • A model with more parameters has a higher memorization capacity

    • Too high capacity causes overfitting, too low causes underfitting

    • In the extreme, the training set can be ‘memorized’ in the weights

  • Smaller models are forced to learn a compressed representation that generalizes better

    • Find the sweet spot: e.g. start with few parameters, increase until overfitting starts.

  • Example: 256 nodes in first layer, 32 nodes in second layer, similar performance

# Reduced-capacity model: 256 -> 32 hidden units, softmax over the 10 classes.
# Trained with early stopping (patience 5) and a live learning-curve plot.
network = models.Sequential([
    layers.Dense(256, activation='relu', kernel_initializer='he_normal', input_shape=(28 * 28,)),
    layers.Dense(32, activation='relu', kernel_initializer='he_normal'),
    layers.Dense(10, activation='softmax'),
])
network.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['accuracy'])
earlystop5 = callbacks.EarlyStopping(monitor='val_loss', patience=5)
plot_losses = TrainingPlot()
history = network.fit(partial_x_train, partial_y_train,
                      validation_data=(x_val, y_val),
                      epochs=30, batch_size=512, verbose=0,
                      callbacks=[plot_losses, earlystop5])
../_images/08 - Neural Networks_85_0.png

Information bottleneck

  • If a layer is too narrow, it will lose information that can never be recovered by subsequent layers

  • Information bottleneck theory defines a bound on the capacity of the network

  • Imagine that you need to learn 10 outputs (e.g. classes) and your hidden layer has 2 nodes

    • This is like trying to learn 10 hyperplanes from a 2-dimensional representation

  • Example: bottleneck of 2 nodes, no overfitting, much higher training loss

# Deliberate information bottleneck: only 2 units in the second hidden layer,
# so 10 classes must be separated in a 2-dimensional representation
network = models.Sequential()
network.add(layers.Dense(256, activation='relu', kernel_initializer='he_normal', input_shape=(28 * 28,)))
network.add(layers.Dense(2, activation='relu', kernel_initializer='he_normal'))
network.add(layers.Dense(10, activation='softmax'))
network.compile(loss='categorical_crossentropy', optimizer='rmsprop', metrics=['accuracy'])
earlystop5 = callbacks.EarlyStopping(monitor='val_loss', patience=5)
plot_losses = TrainingPlot()
history = network.fit(partial_x_train, partial_y_train, epochs=30, batch_size=512, verbose=0,
                      validation_data=(x_val, y_val), callbacks=[plot_losses, earlystop5])
../_images/08 - Neural Networks_87_0.png

Weight regularization (weight decay)

  • As we did many times before, we can also add weight regularization to our loss function

  • L1 regularization: leads to sparse networks with many weights that are 0

  • L2 regularization: leads to many very small weights

# Illustrative slide snippet: add an L2 weight penalty (lambda=0.001) per Dense layer.
# Intentionally incomplete (no output layer / compile); the runnable cell follows.
network = models.Sequential()
network.add(layers.Dense(256, activation='relu', kernel_regularizer=regularizers.l2(0.001), input_shape=(28 * 28,)))
network.add(layers.Dense(128, activation='relu', kernel_regularizer=regularizers.l2(0.001)))
from tensorflow.keras import regularizers

# 256-32-10 network with an L2 penalty (lambda=0.001) on each hidden layer's weights
network = models.Sequential()
network.add(layers.Dense(256, activation='relu', kernel_regularizer=regularizers.l2(0.001), input_shape=(28 * 28,)))
network.add(layers.Dense(32, activation='relu', kernel_regularizer=regularizers.l2(0.001)))
network.add(layers.Dense(10, activation='softmax'))
network.compile(loss='categorical_crossentropy', optimizer='rmsprop', metrics=['accuracy'])
earlystop5 = callbacks.EarlyStopping(monitor='val_loss', patience=5)
plot_losses = TrainingPlot()
history = network.fit(partial_x_train, partial_y_train, epochs=50, batch_size=512, verbose=0,
                      validation_data=(x_val, y_val), callbacks=[plot_losses, earlystop5])
../_images/08 - Neural Networks_90_0.png

Dropout

  • Every iteration, randomly set a number of activations \(a_i\) to 0

  • Dropout rate : fraction of the outputs that are zeroed-out (e.g. 0.1 - 0.5)

  • Idea: break up accidental non-significant learned patterns

  • At test time, nothing is dropped out, but the output values are scaled down by the dropout rate

    • Balances out that more units are active than during training

# Draw a small 2-3-1 network diagram (helper defined earlier in the notebook)
fig = plt.figure(figsize=(4, 4))
ax = fig.gca()
draw_neural_net(ax, [2, 3, 1], draw_bias=True, labels=True, 
                show_activations=True, activation=True)
../_images/08 - Neural Networks_92_0.png

Dropout layers

  • Dropout is usually implemented as a special layer

# Illustrative slide snippet: a Dropout layer (rate 0.5) after each hidden layer
network = models.Sequential()
network.add(layers.Dense(256, activation='relu', input_shape=(28 * 28,)))
network.add(layers.Dropout(0.5))
network.add(layers.Dense(32, activation='relu'))
network.add(layers.Dropout(0.5))
network.add(layers.Dense(10, activation='softmax'))
# Runnable version: same architecture but with a milder dropout rate of 0.3
network = models.Sequential()
network.add(layers.Dense(256, activation='relu', input_shape=(28 * 28,)))
network.add(layers.Dropout(0.3))
network.add(layers.Dense(32, activation='relu'))
network.add(layers.Dropout(0.3))
network.add(layers.Dense(10, activation='softmax'))
network.compile(loss='categorical_crossentropy', optimizer='rmsprop', metrics=['accuracy'])
plot_losses = TrainingPlot()
history = network.fit(partial_x_train, partial_y_train, epochs=50, batch_size=512, verbose=0,
                      validation_data=(x_val, y_val), callbacks=[plot_losses])
../_images/08 - Neural Networks_94_0.png

Batch Normalization

  • We’ve seen that scaling the input is important, but what if layer activations become very large?

    • Same problems, starting deeper in the network

  • Batch normalization: normalize the activations of the previous layer within each batch

    • Within a batch, set the mean activation close to 0 and the standard deviation close to 1

      • Across batches, use an exponential moving average of the batch-wise mean and variance

    • Allows deeper networks that are less prone to vanishing or exploding gradients

# Illustrative slide snippet: BatchNormalization + Dropout after every Dense layer.
# Intentionally incomplete (no output layer / compile); the runnable cell follows.
network = models.Sequential()
network.add(layers.Dense(512, activation='relu', input_shape=(28 * 28,)))
network.add(layers.BatchNormalization())
network.add(layers.Dropout(0.5))
network.add(layers.Dense(256, activation='relu'))
network.add(layers.BatchNormalization())
network.add(layers.Dropout(0.5))
network.add(layers.Dense(64, activation='relu'))
network.add(layers.BatchNormalization())
network.add(layers.Dropout(0.5))
network.add(layers.Dense(32, activation='relu'))
network.add(layers.BatchNormalization())
network.add(layers.Dropout(0.5))

# Runnable model: Dense + BatchNormalization + Dropout blocks (256-64-32),
# softmax output over 10 classes.
network = models.Sequential()
# Fix: first layer width was 265, a typo for 256 (all other widths in this
# lecture are powers of two, and the slide fragment above uses 256)
network.add(layers.Dense(256, activation='relu', input_shape=(28 * 28,)))
network.add(layers.BatchNormalization())
network.add(layers.Dropout(0.5))
network.add(layers.Dense(64, activation='relu'))
network.add(layers.BatchNormalization())
network.add(layers.Dropout(0.5))
network.add(layers.Dense(32, activation='relu'))
network.add(layers.BatchNormalization())
network.add(layers.Dropout(0.5))
network.add(layers.Dense(10, activation='softmax'))
network.compile(loss='categorical_crossentropy', optimizer='rmsprop', metrics=['accuracy'])
plot_losses = TrainingPlot()
history = network.fit(partial_x_train, partial_y_train, epochs=50, batch_size=512, verbose=0,
                      validation_data=(x_val, y_val), callbacks=[plot_losses])
../_images/08 - Neural Networks_98_0.png

Tuning multiple hyperparameters

  • You can wrap Keras models as scikit-learn models and use any tuning technique

  • Keras also has built-in RandomSearch (and HyperBand and BayesianOptimization - see later)

def make_model(hp):
    """Build and compile a model; `hp` supplies tunable hyperparameters (Keras Tuner).

    Tunes the hidden-layer width (32..512 in steps of 32) and the Adam
    learning rate (1e-2, 1e-3 or 1e-4).
    """
    # Fix: the original never constructed the model and mixed the names
    # `m` (used for .add/.compile) and `model` (returned)
    model = models.Sequential()
    model.add(Dense(units=hp.Int('units', min_value=32, max_value=512, step=32)))
    model.compile(optimizer=Adam(hp.Choice('learning rate', [1e-2, 1e-3, 1e-4])))
    return model

# Wrap the Keras builder as a scikit-learn estimator so any sklearn tuner applies
from tensorflow.keras.wrappers.scikit_learn import KerasClassifier
clf = KerasClassifier(make_model)
grid = GridSearchCV(clf, param_grid=param_grid, cv=3)

# Keras Tuner's built-in random search.
# Fix: use the imported RandomSearch (not keras.RandomSearch) and pass the
# builder function defined above (build_model was undefined).
from kerastuner.tuners import RandomSearch
tuner = RandomSearch(make_model, max_trials=5)

Summary

  • Neural architectures

  • Training neural nets

    • Forward pass: Tensor operations

    • Backward pass: Backpropagation

  • Neural network design:

    • Activation functions

    • Weight initialization

    • Optimizers

  • Neural networks in practice

  • Model selection

    • Early stopping

    • Memorization capacity and information bottleneck

    • L1/L2 regularization

    • Dropout

    • Batch normalization